| author | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
| commit | 7b04d813739c0b5595295dffdc86cc41108db2d3 | |
| tree | 8958b612f5d3d665866770ad553e1004aa4b6fb8 /training | |
| parent | Update | |
Update
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 20 | ||||
| -rw-r--r-- | training/sampler.py | 154 | ||||
| -rw-r--r-- | training/strategy/lora.py | 10 |
3 files changed, 171 insertions, 13 deletions
diff --git a/training/functional.py b/training/functional.py index e7e1eb3..eae5681 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -25,6 +25,7 @@ from models.clip.util import get_extended_embeddings | |||
| 25 | from models.clip.tokenizer import MultiCLIPTokenizer | 25 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 26 | from models.convnext.discriminator import ConvNeXtDiscriminator | 26 | from models.convnext.discriminator import ConvNeXtDiscriminator |
| 27 | from training.util import AverageMeter | 27 | from training.util import AverageMeter |
| 28 | from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler | ||
| 28 | from util.slerp import slerp | 29 | from util.slerp import slerp |
| 29 | 30 | ||
| 30 | 31 | ||
| @@ -318,6 +319,7 @@ def get_original( | |||
| 318 | def loss_step( | 319 | def loss_step( |
| 319 | vae: AutoencoderKL, | 320 | vae: AutoencoderKL, |
| 320 | noise_scheduler: SchedulerMixin, | 321 | noise_scheduler: SchedulerMixin, |
| 322 | schedule_sampler: ScheduleSampler, | ||
| 321 | unet: UNet2DConditionModel, | 323 | unet: UNet2DConditionModel, |
| 322 | text_encoder: CLIPTextModel, | 324 | text_encoder: CLIPTextModel, |
| 323 | guidance_scale: float, | 325 | guidance_scale: float, |
| @@ -362,14 +364,7 @@ def loss_step( | |||
| 362 | new_noise = noise + input_pertubation * torch.randn_like(noise) | 364 | new_noise = noise + input_pertubation * torch.randn_like(noise) |
| 363 | 365 | ||
| 364 | # Sample a random timestep for each image | 366 | # Sample a random timestep for each image |
| 365 | timesteps = torch.randint( | 367 | timesteps, weights = schedule_sampler.sample(bsz, latents.device) |
| 366 | 0, | ||
| 367 | noise_scheduler.config.num_train_timesteps, | ||
| 368 | (bsz,), | ||
| 369 | generator=generator, | ||
| 370 | device=latents.device, | ||
| 371 | ) | ||
| 372 | timesteps = timesteps.long() | ||
| 373 | 368 | ||
| 374 | # Add noise to the latents according to the noise magnitude at each timestep | 369 | # Add noise to the latents according to the noise magnitude at each timestep |
| 375 | # (this is the forward diffusion process) | 370 | # (this is the forward diffusion process) |
| @@ -443,6 +438,10 @@ def loss_step( | |||
| 443 | ) | 438 | ) |
| 444 | loss = loss * mse_loss_weights | 439 | loss = loss * mse_loss_weights |
| 445 | 440 | ||
| 441 | if isinstance(schedule_sampler, LossAwareSampler): | ||
| 442 | schedule_sampler.update_with_all_losses(timesteps, loss.detach()) | ||
| 443 | |||
| 444 | loss = loss * weights | ||
| 446 | loss = loss.mean() | 445 | loss = loss.mean() |
| 447 | 446 | ||
| 448 | return loss, acc, bsz | 447 | return loss, acc, bsz |
| @@ -694,6 +693,7 @@ def train( | |||
| 694 | offset_noise_strength: float = 0.01, | 693 | offset_noise_strength: float = 0.01, |
| 695 | input_pertubation: float = 0.1, | 694 | input_pertubation: float = 0.1, |
| 696 | disc: Optional[ConvNeXtDiscriminator] = None, | 695 | disc: Optional[ConvNeXtDiscriminator] = None, |
| 696 | schedule_sampler: Optional[ScheduleSampler] = None, | ||
| 697 | min_snr_gamma: int = 5, | 697 | min_snr_gamma: int = 5, |
| 698 | avg_loss: AverageMeter = AverageMeter(), | 698 | avg_loss: AverageMeter = AverageMeter(), |
| 699 | avg_acc: AverageMeter = AverageMeter(), | 699 | avg_acc: AverageMeter = AverageMeter(), |
| @@ -725,10 +725,14 @@ def train( | |||
| 725 | **kwargs, | 725 | **kwargs, |
| 726 | ) | 726 | ) |
| 727 | 727 | ||
| 728 | if schedule_sampler is None: | ||
| 729 | schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps) | ||
| 730 | |||
| 728 | loss_step_ = partial( | 731 | loss_step_ = partial( |
| 729 | loss_step, | 732 | loss_step, |
| 730 | vae, | 733 | vae, |
| 731 | noise_scheduler, | 734 | noise_scheduler, |
| 735 | schedule_sampler, | ||
| 732 | unet, | 736 | unet, |
| 733 | text_encoder, | 737 | text_encoder, |
| 734 | guidance_scale, | 738 | guidance_scale, |
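The functional.py changes above replace the plain `torch.randint` timestep draw: `loss_step` now takes a `ScheduleSampler`, asks it for timesteps *and* per-sample importance weights, feeds detached losses back to loss-aware samplers, and scales the `mse_loss_weights`-weighted loss by those importance weights before reducing it. A minimal sketch of that flow; `latents`, `bsz`, and the placeholder `loss` stand in for values `loss_step` computes around this point and are not taken from the commit:

```python
import torch

from training.sampler import LossAwareSampler, UniformSampler

# Stand-ins for values computed earlier in loss_step (illustrative only).
num_train_timesteps = 1000
bsz = 4
latents = torch.randn(bsz, 4, 64, 64)
schedule_sampler = UniformSampler(num_train_timesteps)

# One timestep per image, plus an importance weight per sample.
timesteps, weights = schedule_sampler.sample(bsz, latents.device)

# ... forward diffusion, model prediction, and the per-sample MSE would go here ...
loss = torch.rand(bsz)  # placeholder for the per-sample loss

# Loss-aware samplers receive the detached losses so they can adapt their
# sampling distribution; UniformSampler has no such update step.
if isinstance(schedule_sampler, LossAwareSampler):
    schedule_sampler.update_with_all_losses(timesteps, loss.detach())

# Importance weights keep the expected objective unchanged.
loss = (loss * weights).mean()
```

With the default `UniformSampler` every weight is 1.0, so behaviour matches the old uniform draw, except that timesteps now come from NumPy's RNG rather than the torch `generator` that was previously passed to `torch.randint`.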
diff --git a/training/sampler.py b/training/sampler.py new file mode 100644 index 0000000..8afe255 --- /dev/null +++ b/training/sampler.py | |||
| @@ -0,0 +1,154 @@ | |||
| 1 | from abc import ABC, abstractmethod | ||
| 2 | |||
| 3 | import numpy as np | ||
| 4 | import torch | ||
| 5 | import torch.distributed as dist | ||
| 6 | |||
| 7 | |||
| 8 | def create_named_schedule_sampler(name, num_timesteps): | ||
| 9 | """ | ||
| 10 | Create a ScheduleSampler from a library of pre-defined samplers. | ||
| 11 | |||
| 12 | :param name: the name of the sampler. | ||
| 13 | :param num_timesteps: the number of diffusion timesteps to sample over. | ||
| 14 | """ | ||
| 15 | if name == "uniform": | ||
| 16 | return UniformSampler(num_timesteps) | ||
| 17 | elif name == "loss-second-moment": | ||
| 18 | return LossSecondMomentResampler(num_timesteps) | ||
| 19 | else: | ||
| 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") | ||
| 21 | |||
| 22 | |||
| 23 | class ScheduleSampler(ABC): | ||
| 24 | """ | ||
| 25 | A distribution over timesteps in the diffusion process, intended to reduce | ||
| 26 | variance of the objective. | ||
| 27 | |||
| 28 | By default, samplers perform unbiased importance sampling, in which the | ||
| 29 | objective's mean is unchanged. | ||
| 30 | However, subclasses may override sample() to change how the resampled | ||
| 31 | terms are reweighted, allowing for actual changes in the objective. | ||
| 32 | """ | ||
| 33 | |||
| 34 | @abstractmethod | ||
| 35 | def weights(self): | ||
| 36 | """ | ||
| 37 | Get a numpy array of weights, one per diffusion step. | ||
| 38 | |||
| 39 | The weights needn't be normalized, but must be positive. | ||
| 40 | """ | ||
| 41 | |||
| 42 | def sample(self, batch_size, device): | ||
| 43 | """ | ||
| 44 | Importance-sample timesteps for a batch. | ||
| 45 | |||
| 46 | :param batch_size: the number of timesteps to sample. | ||
| 47 | :param device: the torch device to place the sampled tensors on. | ||
| 48 | :return: a tuple (timesteps, weights): | ||
| 49 | - timesteps: a tensor of timestep indices. | ||
| 50 | - weights: a tensor of weights to scale the resulting losses. | ||
| 51 | """ | ||
| 52 | w = self.weights() | ||
| 53 | p = w / np.sum(w) | ||
| 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) | ||
| 55 | indices = torch.from_numpy(indices_np).long().to(device) | ||
| 56 | weights_np = 1 / (len(p) * p[indices_np]) | ||
| 57 | weights = torch.from_numpy(weights_np).float().to(device) | ||
| 58 | return indices, weights | ||
| 59 | |||
| 60 | |||
| 61 | class UniformSampler(ScheduleSampler): | ||
| 62 | def __init__(self, num_timesteps): | ||
| 63 | self.num_timesteps = num_timesteps | ||
| 64 | self._weights = np.ones([num_timesteps]) | ||
| 65 | |||
| 66 | def weights(self): | ||
| 67 | return self._weights | ||
| 68 | |||
| 69 | |||
| 70 | class LossAwareSampler(ScheduleSampler): | ||
| 71 | def update_with_local_losses(self, local_ts, local_losses): | ||
| 72 | """ | ||
| 73 | Update the reweighting using losses from a model. | ||
| 74 | |||
| 75 | Call this method from each rank with a batch of timesteps and the | ||
| 76 | corresponding losses for each of those timesteps. | ||
| 77 | This method will perform synchronization to make sure all of the ranks | ||
| 78 | maintain the exact same reweighting. | ||
| 79 | |||
| 80 | :param local_ts: an integer Tensor of timesteps. | ||
| 81 | :param local_losses: a 1D Tensor of losses. | ||
| 82 | """ | ||
| 83 | batch_sizes = [ | ||
| 84 | torch.tensor([0], dtype=torch.int32, device=local_ts.device) | ||
| 85 | for _ in range(dist.get_world_size()) | ||
| 86 | ] | ||
| 87 | dist.all_gather( | ||
| 88 | batch_sizes, | ||
| 89 | torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device), | ||
| 90 | ) | ||
| 91 | |||
| 92 | # Pad all_gather batches to be the maximum batch size. | ||
| 93 | batch_sizes = [x.item() for x in batch_sizes] | ||
| 94 | max_bs = max(batch_sizes) | ||
| 95 | |||
| 96 | timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes] | ||
| 97 | loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes] | ||
| 98 | dist.all_gather(timestep_batches, local_ts) | ||
| 99 | dist.all_gather(loss_batches, local_losses) | ||
| 100 | timesteps = [ | ||
| 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] | ||
| 102 | ] | ||
| 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] | ||
| 104 | self.update_with_all_losses(timesteps, losses) | ||
| 105 | |||
| 106 | @abstractmethod | ||
| 107 | def update_with_all_losses(self, ts, losses): | ||
| 108 | """ | ||
| 109 | Update the reweighting using losses from a model. | ||
| 110 | |||
| 111 | Sub-classes should override this method to update the reweighting | ||
| 112 | using losses from the model. | ||
| 113 | |||
| 114 | This method directly updates the reweighting without synchronizing | ||
| 115 | between workers. It is called by update_with_local_losses from all | ||
| 116 | ranks with identical arguments. Thus, it should have deterministic | ||
| 117 | behavior to maintain state across workers. | ||
| 118 | |||
| 119 | :param ts: a list of int timesteps. | ||
| 120 | :param losses: a list of float losses, one per timestep. | ||
| 121 | """ | ||
| 122 | |||
| 123 | |||
| 124 | class LossSecondMomentResampler(LossAwareSampler): | ||
| 125 | def __init__(self, num_timesteps, history_per_term=10, uniform_prob=0.001): | ||
| 126 | self.num_timesteps = num_timesteps | ||
| 127 | self.history_per_term = history_per_term | ||
| 128 | self.uniform_prob = uniform_prob | ||
| 129 | self._loss_history = np.zeros( | ||
| 130 | [self.num_timesteps, history_per_term], dtype=np.float64 | ||
| 131 | ) | ||
| 132 | self._loss_counts = np.zeros([self.num_timesteps], dtype=int) | ||
| 133 | |||
| 134 | def weights(self): | ||
| 135 | if not self._warmed_up(): | ||
| 136 | return np.ones([self.num_timesteps], dtype=np.float64) | ||
| 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) | ||
| 138 | weights /= np.sum(weights) | ||
| 139 | weights *= 1 - self.uniform_prob | ||
| 140 | weights += self.uniform_prob / len(weights) | ||
| 141 | return weights | ||
| 142 | |||
| 143 | def update_with_all_losses(self, ts, losses): | ||
| 144 | for t, loss in zip(ts, losses): | ||
| 145 | if self._loss_counts[t] == self.history_per_term: | ||
| 146 | # Shift out the oldest loss term. | ||
| 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] | ||
| 148 | self._loss_history[t, -1] = loss | ||
| 149 | else: | ||
| 150 | self._loss_history[t, self._loss_counts[t]] = loss | ||
| 151 | self._loss_counts[t] += 1 | ||
| 152 | |||
| 153 | def _warmed_up(self): | ||
| 154 | return (self._loss_counts == self.history_per_term).all() | ||
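The new training/sampler.py exposes both samplers through `create_named_schedule_sampler`: `"uniform"` reproduces plain uniform timestep sampling, while `"loss-second-moment"` keeps a rolling window of `history_per_term` losses per timestep and, once every timestep's window is full, draws timesteps with probability proportional to `sqrt(E[L_t^2])`, blended with a small uniform floor (`uniform_prob`). A short single-process usage sketch with made-up numbers, not taken from the repository:

```python
import torch

from training.sampler import create_named_schedule_sampler

# Small timestep count so the loss history warms up quickly (illustrative).
sampler = create_named_schedule_sampler("loss-second-moment", 100)

for _ in range(500):
    # Timesteps plus the importance weights that undo the biased draw.
    timesteps, weights = sampler.sample(batch_size=8, device=torch.device("cpu"))

    # In training these would be the per-sample diffusion losses.
    losses = torch.rand(8)

    # Feed losses back; weights() stays uniform until every timestep has
    # history_per_term recorded losses, then switches to the second-moment rule.
    sampler.update_with_all_losses(timesteps, losses)

print(sampler.weights()[:5])  # per-timestep weights (all 1.0 until warmed up)
```

Because `sample()` returns importance weights of `1 / (T * p_t)`, the reweighted objective keeps the same expectation as uniform sampling; only its variance changes.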
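`update_with_local_losses` is the distributed entry point of `LossAwareSampler`: every rank passes just its local timesteps and losses, the method all-gathers them (padded to the largest per-rank batch), and each rank then calls `update_with_all_losses` with the identical combined lists so their loss histories stay in sync. A hedged sketch of choosing between the two entry points; the single-process branch mirrors the direct call wired into `loss_step` in this commit:

```python
import torch
import torch.distributed as dist

from training.sampler import LossSecondMomentResampler

sampler = LossSecondMomentResampler(num_timesteps=1000)
timesteps, weights = sampler.sample(batch_size=8, device=torch.device("cpu"))
losses = torch.rand(8)  # placeholder per-sample losses

if dist.is_available() and dist.is_initialized():
    # Each rank contributes only its local batch; the sampler gathers all of
    # them so every rank updates its loss history with identical data.
    sampler.update_with_local_losses(timesteps, losses)
else:
    # Single-process path, as used by loss_step above.
    sampler.update_with_all_losses(timesteps, losses)
```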
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 3f4dbbc..0c0f633 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -120,11 +120,11 @@ def lora_strategy_callbacks( | |||
| 120 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 120 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
| 121 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 121 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
| 122 | 122 | ||
| 123 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 123 | # for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): |
| 124 | text_encoder_.text_model.embeddings.save_embed( | 124 | # text_encoder_.text_model.embeddings.save_embed( |
| 125 | ids, | 125 | # ids, |
| 126 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 126 | # checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" |
| 127 | ) | 127 | # ) |
| 128 | 128 | ||
| 129 | if not pti_mode: | 129 | if not pti_mode: |
| 130 | lora_config = {} | 130 | lora_config = {} |
