From 3dd234d8fe86b7d813aec9f43aeb765a88b6a916 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 5 Apr 2023 17:33:06 +0200
Subject: Improved slerp noise offset: Dedicated black image instead of negative offset

---
 training/functional.py | 112 ++++++++++++++++++-------------------------------
 1 file changed, 40 insertions(+), 72 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index e7f02cb..68071bc 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,7 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
-from util.noise import perlin_noise
+from util.slerp import slerp
 
 
 def const(result=None):
@@ -270,62 +270,16 @@ def snr_weight(noisy_latents, latents, gamma):
     )
 
 
-def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
-    """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.
-
-    `DOT_THR` determines when the vectors are too close to parallel.
-    If they are too close, then a regular linear interpolation is used.
-
-    `to_cpu` is a flag that optionally computes SLERP on the CPU.
-    If the input tensors were on a GPU, it moves them back after the computation.
-
-    `zdim` is the feature dimension over which to compute norms and find angles.
-    For example: if a sequence of 5 vectors is input with shape [5, 768]
-    Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
-
-    Theory Reference:
-    https://splines.readthedocs.io/en/latest/rotation/slerp.html
-    PyTorch reference:
-    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
-    Numpy reference:
-    https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
-    """
-
-    # check if we need to move to the cpu
-    if to_cpu:
-        orig_device = v1.device
-        v1, v2 = v1.to('cpu'), v2.to('cpu')
-
-    # take the dot product between normalized vectors
-    v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
-    v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
-    dot = (v1_norm * v2_norm).sum(zdim)
-
-    for _ in range(len(dot.shape), len(v1.shape)):
-        dot = dot[..., None]
-
-    # if the vectors are too close, return a simple linear interpolation
-    if (torch.abs(dot) > DOT_THR).any():
-        res = (1 - t) * v1 + t * v2
-    else:
-        # compute the angle terms we need
-        theta = torch.acos(dot)
-        theta_t = theta * t
-        sin_theta = torch.sin(theta)
-        sin_theta_t = torch.sin(theta_t)
-
-        # compute the sine scaling terms for the vectors
-        s1 = torch.sin(theta - theta_t) / sin_theta
-        s2 = sin_theta_t / sin_theta
-
-        # interpolate the vectors
-        res = s1 * v1 + s2 * v2
-
-    # check if we need to move them back to the original device
-    if to_cpu:
-        res.to(orig_device)
-
-    return res
+def make_solid_image(color: float, shape, vae, dtype, device, generator):
+    img = torch.tensor(
+        [[[[color]]]],
+        dtype=dtype,
+        device=device
+    ).expand(1, *shape)
+    img = img * 2 - 1
+    img = vae.encode(img).latent_dist.sample(generator=generator)
+    img *= vae.config.scaling_factor
+    return img
 
 
 def loss_step(
@@ -361,20 +315,29 @@ def loss_step(
     )
 
     if offset_noise_strength != 0:
-        cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
-
-        if cache_key not in cache:
-            img_white = torch.tensor(
-                [[[[1]]]],
-                dtype=latents.dtype,
-                device=latents.device
-            ).expand(1, images.shape[1], images.shape[2], images.shape[3])
-            img_white = img_white * 2 - 1
-            img_white = vae.encode(img_white).latent_dist.sample(generator=generator)
-            img_white *= vae.config.scaling_factor
-            cache[cache_key] = img_white
+        solid_image = partial(
+            make_solid_image,
+            shape=images.shape[1:],
+            vae=vae,
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        )
+
+        white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
+        black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}"
+
+        if white_cache_key not in cache:
+            img_white = solid_image(1)
+            cache[white_cache_key] = img_white
         else:
-            img_white = cache[cache_key]
+            img_white = cache[white_cache_key]
+
+        if black_cache_key not in cache:
+            img_black = solid_image(0)
+            cache[black_cache_key] = img_black
+        else:
+            img_black = cache[black_cache_key]
 
         offset_strength = torch.rand(
             (bsz, 1, 1, 1),
@@ -384,8 +347,13 @@ def loss_step(
             generator=generator
        )
         offset_strength = offset_noise_strength * (offset_strength * 2 - 1)
-        offset_strength = offset_strength.expand(noise.shape)
-        noise = slerp(noise, img_white.expand(noise.shape), offset_strength, zdim=(-1, -2))
+        offset_images = torch.where(
+            offset_strength >= 0,
+            img_white.expand(noise.shape),
+            img_black.expand(noise.shape)
+        )
+        offset_strength = offset_strength.abs().expand(noise.shape)
+        noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2))
 
     # Sample a random timestep for each image
     timesteps = torch.randint(
--
cgit v1.2.3-54-g00ecf
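
Note: the old code passed a signed interpolation factor straight into slerp (a negative
offset meant slerping toward white by a negative `t`, which is not well defined). The
patch instead splits on the sign: positive strengths interpolate toward a solid white
latent, negative ones toward a dedicated solid black latent, and only the magnitude is
used as the slerp factor. Below is a minimal, self-contained sketch of that scheme; the
`slerp` body mirrors the function removed above (now imported from `util.slerp`), and
the random `img_white`/`img_black` tensors are hypothetical stand-ins for the
VAE-encoded solid images that `make_solid_image()` would produce.

    import torch


    def slerp(v1, v2, t, DOT_THR=0.9995, zdim=-1):
        # Spherical interpolation from v1 to v2 by factor t, with norms and
        # angles taken over the dims in zdim; falls back to plain lerp when
        # the inputs are nearly parallel.
        v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
        v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
        dot = (v1_norm * v2_norm).sum(zdim)
        for _ in range(len(dot.shape), len(v1.shape)):
            dot = dot[..., None]
        if (torch.abs(dot) > DOT_THR).any():
            return (1 - t) * v1 + t * v2
        theta = torch.acos(dot)
        sin_theta = torch.sin(theta)
        s1 = torch.sin(theta - theta * t) / sin_theta
        s2 = torch.sin(theta * t) / sin_theta
        return s1 * v1 + s2 * v2


    # Stand-ins: random tensors in place of the sampled Gaussian noise and
    # the VAE-encoded solid white/black latents (SD latent shapes assumed).
    bsz, offset_noise_strength = 4, 0.15
    noise = torch.randn(bsz, 4, 64, 64)
    img_white = torch.randn(1, 4, 64, 64)
    img_black = torch.randn(1, 4, 64, 64)

    # Per-sample strength in [-s, s]: the sign selects the target image,
    # the magnitude becomes the (non-negative) slerp factor.
    offset_strength = offset_noise_strength * (torch.rand(bsz, 1, 1, 1) * 2 - 1)
    offset_images = torch.where(
        offset_strength >= 0,
        img_white.expand(noise.shape),
        img_black.expand(noise.shape),
    )
    noise = slerp(noise, offset_images,
                  offset_strength.abs().expand(noise.shape), zdim=(-1, -2))

Because both solid latents depend only on the image resolution, the patch caches them
per-resolution, so the VAE encodes each solid color once per training run rather than
once per step.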