From 7b04d813739c0b5595295dffdc86cc41108db2d3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 6 May 2023 16:25:36 +0200 Subject: Update --- training/functional.py | 20 +++--- training/sampler.py | 154 ++++++++++++++++++++++++++++++++++++++++++++++ training/strategy/lora.py | 10 +-- 3 files changed, 171 insertions(+), 13 deletions(-) create mode 100644 training/sampler.py (limited to 'training') diff --git a/training/functional.py b/training/functional.py index e7e1eb3..eae5681 100644 --- a/training/functional.py +++ b/training/functional.py @@ -25,6 +25,7 @@ from models.clip.util import get_extended_embeddings from models.clip.tokenizer import MultiCLIPTokenizer from models.convnext.discriminator import ConvNeXtDiscriminator from training.util import AverageMeter +from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler from util.slerp import slerp @@ -318,6 +319,7 @@ def get_original( def loss_step( vae: AutoencoderKL, noise_scheduler: SchedulerMixin, + schedule_sampler: ScheduleSampler, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, guidance_scale: float, @@ -362,14 +364,7 @@ def loss_step( new_noise = noise + input_pertubation * torch.randn_like(noise) # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (bsz,), - generator=generator, - device=latents.device, - ) - timesteps = timesteps.long() + timesteps, weights = schedule_sampler.sample(bsz, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -443,6 +438,10 @@ def loss_step( ) loss = loss * mse_loss_weights + if isinstance(schedule_sampler, LossAwareSampler): + schedule_sampler.update_with_all_losses(timesteps, loss.detach()) + + loss = loss * weights loss = loss.mean() return loss, acc, bsz @@ -694,6 +693,7 @@ def train( offset_noise_strength: float = 0.01, input_pertubation: float = 0.1, disc: Optional[ConvNeXtDiscriminator] = None, + schedule_sampler: Optional[ScheduleSampler] = None, min_snr_gamma: int = 5, avg_loss: AverageMeter = AverageMeter(), avg_acc: AverageMeter = AverageMeter(), @@ -725,10 +725,14 @@ def train( **kwargs, ) + if schedule_sampler is None: + schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps) + loss_step_ = partial( loss_step, vae, noise_scheduler, + schedule_sampler, unet, text_encoder, guidance_scale, diff --git a/training/sampler.py b/training/sampler.py new file mode 100644 index 0000000..8afe255 --- /dev/null +++ b/training/sampler.py @@ -0,0 +1,154 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch +import torch.distributed as dist + + +def create_named_schedule_sampler(name, num_timesteps): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(num_timesteps) + elif name == "loss-second-moment": + return LossSecondMomentResampler(num_timesteps) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = torch.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = torch.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, num_timesteps): + self.num_timesteps = num_timesteps + self._weights = np.ones([num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + torch.tensor([0], dtype=torch.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + + Sub-classes should override this method to update the reweighting + using losses from the model. + + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, num_timesteps, history_per_term=10, uniform_prob=0.001): + self.num_timesteps = num_timesteps + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [self.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 3f4dbbc..0c0f633 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -120,11 +120,11 @@ def lora_strategy_callbacks( unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) - for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): - text_encoder_.text_model.embeddings.save_embed( - ids, - checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" - ) + # for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): + # text_encoder_.text_model.embeddings.save_embed( + # ids, + # checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" + # ) if not pti_mode: lora_config = {} -- cgit v1.2.3-70-g09d2