Diffstat (limited to 'training')

 -rw-r--r--  training/functional.py    |  20
 -rw-r--r--  training/sampler.py       | 154
 -rw-r--r--  training/strategy/lora.py |  10

 3 files changed, 171 insertions, 13 deletions
diff --git a/training/functional.py b/training/functional.py
index e7e1eb3..eae5681 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,6 +25,7 @@ from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.util import AverageMeter
+from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
 from util.slerp import slerp
 
 
@@ -318,6 +319,7 @@ def get_original(
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: SchedulerMixin,
+    schedule_sampler: ScheduleSampler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
     guidance_scale: float,
@@ -362,14 +364,7 @@ def loss_step(
         new_noise = noise + input_pertubation * torch.randn_like(noise)
 
     # Sample a random timestep for each image
-    timesteps = torch.randint(
-        0,
-        noise_scheduler.config.num_train_timesteps,
-        (bsz,),
-        generator=generator,
-        device=latents.device,
-    )
-    timesteps = timesteps.long()
+    timesteps, weights = schedule_sampler.sample(bsz, latents.device)
 
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
@@ -443,6 +438,10 @@ def loss_step(
         )
         loss = loss * mse_loss_weights
 
+    if isinstance(schedule_sampler, LossAwareSampler):
+        schedule_sampler.update_with_all_losses(timesteps, loss.detach())
+
+    loss = loss * weights
     loss = loss.mean()
 
     return loss, acc, bsz
@@ -694,6 +693,7 @@ def train(
     offset_noise_strength: float = 0.01,
     input_pertubation: float = 0.1,
     disc: Optional[ConvNeXtDiscriminator] = None,
+    schedule_sampler: Optional[ScheduleSampler] = None,
     min_snr_gamma: int = 5,
     avg_loss: AverageMeter = AverageMeter(),
     avg_acc: AverageMeter = AverageMeter(),
@@ -725,10 +725,14 @@ def train(
         **kwargs,
     )
 
+    if schedule_sampler is None:
+        schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps)
+
     loss_step_ = partial(
         loss_step,
         vae,
         noise_scheduler,
+        schedule_sampler,
         unet,
         text_encoder,
         guidance_scale,
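Taken together, the functional.py changes replace the inline torch.randint call with a pluggable ScheduleSampler: loss_step receives the sampler, asks it for (timesteps, weights), rescales the per-example loss by those weights before averaging, and reports the detached losses back when the sampler is loss-aware, while train() falls back to UniformSampler when no sampler is passed. One side effect worth noting is that timestep selection now goes through NumPy's global RNG rather than the passed-in generator. The snippet below is a minimal, self-contained sketch of that sampling-and-weighting contract using the UniformSampler fallback; the batch size and the toy per-example loss are placeholders, not values taken from this diff.

import torch

from training.sampler import UniformSampler

# Stand-in for noise_scheduler.config.num_train_timesteps.
T = 1000
# The same fallback train() constructs when schedule_sampler is None.
sampler = UniformSampler(T)

bsz = 8
timesteps, weights = sampler.sample(bsz, torch.device("cpu"))

# Toy per-example loss standing in for the MSE computed in loss_step.
loss = torch.rand(bsz)

# Mirrors the new code path: weight each example, then reduce.
loss = loss * weights
loss = loss.mean()

# For UniformSampler every weight is 1.0, so the objective is unchanged;
# importance samplers instead return weights = 1 / (T * p(t)) to stay unbiased.
print(timesteps.dtype, weights.unique(), loss.item())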
diff --git a/training/sampler.py b/training/sampler.py
new file mode 100644
index 0000000..8afe255
--- /dev/null
+++ b/training/sampler.py
@@ -0,0 +1,154 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, num_timesteps):
+    """
+    Create a ScheduleSampler from a library of pre-defined samplers.
+
+    :param name: the name of the sampler.
+    :param num_timesteps: the number of diffusion timesteps to sample over.
+    """
+    if name == "uniform":
+        return UniformSampler(num_timesteps)
+    elif name == "loss-second-moment":
+        return LossSecondMomentResampler(num_timesteps)
+    else:
+        raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+    """
+    A distribution over timesteps in the diffusion process, intended to reduce
+    variance of the objective.
+
+    By default, samplers perform unbiased importance sampling, in which the
+    objective's mean is unchanged.
+    However, subclasses may override sample() to change how the resampled
+    terms are reweighted, allowing for actual changes in the objective.
+    """
+
+    @abstractmethod
+    def weights(self):
+        """
+        Get a numpy array of weights, one per diffusion step.
+
+        The weights needn't be normalized, but must be positive.
+        """
+
+    def sample(self, batch_size, device):
+        """
+        Importance-sample timesteps for a batch.
+
+        :param batch_size: the number of timesteps to draw.
+        :param device: the torch device to place the resulting tensors on.
+        :return: a tuple (timesteps, weights):
+                 - timesteps: a tensor of timestep indices.
+                 - weights: a tensor of weights to scale the resulting losses.
+        """
+        w = self.weights()
+        p = w / np.sum(w)
+        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+        indices = torch.from_numpy(indices_np).long().to(device)
+        weights_np = 1 / (len(p) * p[indices_np])
+        weights = torch.from_numpy(weights_np).float().to(device)
+        return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+    def __init__(self, num_timesteps):
+        self.num_timesteps = num_timesteps
+        self._weights = np.ones([num_timesteps])
+
+    def weights(self):
+        return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+    def update_with_local_losses(self, local_ts, local_losses):
+        """
+        Update the reweighting using losses from a model.
+
+        Call this method from each rank with a batch of timesteps and the
+        corresponding losses for each of those timesteps.
+        This method will perform synchronization to make sure all of the ranks
+        maintain the exact same reweighting.
+
+        :param local_ts: an integer Tensor of timesteps.
+        :param local_losses: a 1D Tensor of losses.
+        """
+        batch_sizes = [
+            torch.tensor([0], dtype=torch.int32, device=local_ts.device)
+            for _ in range(dist.get_world_size())
+        ]
+        dist.all_gather(
+            batch_sizes,
+            torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device),
+        )
+
+        # Pad all_gather batches to be the maximum batch size.
+        batch_sizes = [x.item() for x in batch_sizes]
+        max_bs = max(batch_sizes)
+
+        timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+        loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+        dist.all_gather(timestep_batches, local_ts)
+        dist.all_gather(loss_batches, local_losses)
+        timesteps = [
+            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+        ]
+        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+        self.update_with_all_losses(timesteps, losses)
+
+    @abstractmethod
+    def update_with_all_losses(self, ts, losses):
+        """
+        Update the reweighting using losses from a model.
+
+        Sub-classes should override this method to update the reweighting
+        using losses from the model.
+
+        This method directly updates the reweighting without synchronizing
+        between workers. It is called by update_with_local_losses from all
+        ranks with identical arguments. Thus, it should have deterministic
+        behavior to maintain state across workers.
+
+        :param ts: a list of int timesteps.
+        :param losses: a list of float losses, one per timestep.
+        """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+    def __init__(self, num_timesteps, history_per_term=10, uniform_prob=0.001):
+        self.num_timesteps = num_timesteps
+        self.history_per_term = history_per_term
+        self.uniform_prob = uniform_prob
+        self._loss_history = np.zeros(
+            [self.num_timesteps, history_per_term], dtype=np.float64
+        )
+        self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int64)
+
+    def weights(self):
+        if not self._warmed_up():
+            return np.ones([self.num_timesteps], dtype=np.float64)
+        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+        weights /= np.sum(weights)
+        weights *= 1 - self.uniform_prob
+        weights += self.uniform_prob / len(weights)
+        return weights
+
+    def update_with_all_losses(self, ts, losses):
+        for t, loss in zip(ts, losses):
+            if self._loss_counts[t] == self.history_per_term:
+                # Shift out the oldest loss term.
+                self._loss_history[t, :-1] = self._loss_history[t, 1:]
+                self._loss_history[t, -1] = loss
+            else:
+                self._loss_history[t, self._loss_counts[t]] = loss
+                self._loss_counts[t] += 1
+
+    def _warmed_up(self):
+        return (self._loss_counts == self.history_per_term).all()
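training/sampler.py closely follows the ScheduleSampler utilities from OpenAI's guided-diffusion code: UniformSampler reproduces plain uniform timestep sampling, while LossSecondMomentResampler importance-samples timesteps in proportion to the square root of the second moment of their recent losses, falling back to uniform weights until every timestep has history_per_term recorded losses. In a single process, loss_step feeds it through update_with_all_losses directly; update_with_local_losses is only needed when torch.distributed is initialized. Below is a small illustrative demo of the warm-up and reweighting behavior; the tiny num_timesteps, the reduced history_per_term, and the synthetic losses are chosen for the example and are not training defaults.

import numpy as np

from training.sampler import LossSecondMomentResampler

T = 10
sampler = LossSecondMomentResampler(T, history_per_term=2, uniform_prob=0.001)

# Before warm-up (history_per_term losses per timestep), weights are uniform.
assert np.allclose(sampler.weights(), 1.0)

# Record two rounds of synthetic losses that grow with the timestep index.
for _ in range(2):
    ts = list(range(T))
    losses = [0.1 + 0.1 * t for t in ts]
    sampler.update_with_all_losses(ts, losses)

w = sampler.weights()
print(w)                 # normalized, skewed toward the high-loss (late) timesteps
print(int(w.argmax()))   # 9: the timestep with the largest recorded losses

# Sampling now favors high-loss timesteps, and the returned weights
# (1 / (T * p(t))) compensate so the expected loss stays unbiased.
timesteps, weights = sampler.sample(4, "cpu")
print(timesteps, weights)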
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 3f4dbbc..0c0f633 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -120,11 +120,11 @@ def lora_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
-        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
-            text_encoder_.text_model.embeddings.save_embed(
-                ids,
-                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
-            )
+        # for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+        #     text_encoder_.text_model.embeddings.save_embed(
+        #         ids,
+        #         checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+        #     )
 
         if not pti_mode:
             lora_config = {}