Diffstat (limited to 'training')

 training/functional.py | 112 ++++++++++++++++++++------------------------------------
 1 file changed, 40 insertions(+), 72 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index e7f02cb..68071bc 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,7 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
-from util.noise import perlin_noise
+from util.slerp import slerp
 
 
 def const(result=None):
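
This import swap mirrors the code move in the next hunk: the module-local `slerp` implementation is deleted from `training/functional.py` and imported from `util.slerp` instead, while the `perlin_noise` import, evidently no longer referenced in this file, is dropped along with it.
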
@@ -270,62 +270,16 @@ def snr_weight(noisy_latents, latents, gamma):
     )
 
 
-def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
-    """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.
-
-    `DOT_THR` determines when the vectors are too close to parallel.
-    If they are too close, then a regular linear interpolation is used.
-
-    `to_cpu` is a flag that optionally computes SLERP on the CPU.
-    If the input tensors were on a GPU, it moves them back after the computation.
-
-    `zdim` is the feature dimension over which to compute norms and find angles.
-    For example: if a sequence of 5 vectors is input with shape [5, 768]
-    Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
-
-    Theory Reference:
-    https://splines.readthedocs.io/en/latest/rotation/slerp.html
-    PyTorch reference:
-    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
-    Numpy reference:
-    https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
-    """
-
-    # check if we need to move to the cpu
-    if to_cpu:
-        orig_device = v1.device
-        v1, v2 = v1.to('cpu'), v2.to('cpu')
-
-    # take the dot product between normalized vectors
-    v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
-    v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
-    dot = (v1_norm * v2_norm).sum(zdim)
-
-    for _ in range(len(dot.shape), len(v1.shape)):
-        dot = dot[..., None]
-
-    # if the vectors are too close, return a simple linear interpolation
-    if (torch.abs(dot) > DOT_THR).any():
-        res = (1 - t) * v1 + t * v2
-    else:
-        # compute the angle terms we need
-        theta = torch.acos(dot)
-        theta_t = theta * t
-        sin_theta = torch.sin(theta)
-        sin_theta_t = torch.sin(theta_t)
-
-        # compute the sine scaling terms for the vectors
-        s1 = torch.sin(theta - theta_t) / sin_theta
-        s2 = sin_theta_t / sin_theta
-
-        # interpolate the vectors
-        res = s1 * v1 + s2 * v2
-
-    # check if we need to move them back to the original device
-    if to_cpu:
-        res.to(orig_device)
-
-    return res
+def make_solid_image(color: float, shape, vae, dtype, device, generator):
+    img = torch.tensor(
+        [[[[color]]]],
+        dtype=dtype,
+        device=device
+    ).expand(1, *shape)
+    img = img * 2 - 1
+    img = vae.encode(img).latent_dist.sample(generator=generator)
+    img *= vae.config.scaling_factor
+    return img
 
 
 def loss_step(
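
`make_solid_image` factors out the white-image encoding that the old `loss_step` performed inline, generalized to any constant color: a scalar in [0, 1] is expanded to a full (1, C, H, W) image, remapped to the [-1, 1] range the VAE expects, encoded, and scaled by the VAE's scaling factor. A minimal sketch of the tensor plumbing, with the VAE encode step left as comments since it needs a real model (`shape` corresponds to `images.shape[1:]` in `loss_step`):

    import torch

    shape = (3, 64, 64)   # (C, H, W); stand-in for images.shape[1:]
    color = 1.0           # 1.0 -> white, 0.0 -> black

    # A (1, 1, 1, 1) scalar expanded to a constant (1, C, H, W) image
    img = torch.tensor([[[[color]]]]).expand(1, *shape)
    img = img * 2 - 1     # [0, 1] -> [-1, 1], the VAE's input range
    assert img.shape == (1, 3, 64, 64)

    # With a real diffusers AutoencoderKL, the function then does:
    # img = vae.encode(img).latent_dist.sample(generator=generator)
    # img *= vae.config.scaling_factor
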
@@ -361,20 +315,29 @@ def loss_step(
     )
 
     if offset_noise_strength != 0:
-        cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
-
-        if cache_key not in cache:
-            img_white = torch.tensor(
-                [[[[1]]]],
-                dtype=latents.dtype,
-                device=latents.device
-            ).expand(1, images.shape[1], images.shape[2], images.shape[3])
-            img_white = img_white * 2 - 1
-            img_white = vae.encode(img_white).latent_dist.sample(generator=generator)
-            img_white *= vae.config.scaling_factor
-            cache[cache_key] = img_white
-        else:
-            img_white = cache[cache_key]
+        solid_image = partial(
+            make_solid_image,
+            shape=images.shape[1:],
+            vae=vae,
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        )
+
+        white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
+        black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}"
+
+        if white_cache_key not in cache:
+            img_white = solid_image(1)
+            cache[white_cache_key] = img_white
+        else:
+            img_white = cache[white_cache_key]
+
+        if black_cache_key not in cache:
+            img_black = solid_image(0)
+            cache[black_cache_key] = img_black
+        else:
+            img_black = cache[black_cache_key]
 
         offset_strength = torch.rand(
             (bsz, 1, 1, 1),
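
Both solid latents are cached per resolution (the keys encode height and width), so each color is pushed through the VAE at most once per image size. Note that `functools.partial` is used to freeze the shared arguments; this hunk does not add that import, so `partial` is presumably already imported at the top of the module.
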
@@ -384,8 +347,13 @@ def loss_step(
             generator=generator
         )
         offset_strength = offset_noise_strength * (offset_strength * 2 - 1)
-        offset_strength = offset_strength.expand(noise.shape)
-        noise = slerp(noise, img_white.expand(noise.shape), offset_strength, zdim=(-1, -2))
+        offset_images = torch.where(
+            offset_strength >= 0,
+            img_white.expand(noise.shape),
+            img_black.expand(noise.shape)
+        )
+        offset_strength = offset_strength.abs().expand(noise.shape)
+        noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2))
 
         # Sample a random timestep for each image
         timesteps = torch.randint(
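
The offset-noise target is now symmetric: `offset_strength` is drawn from [-offset_noise_strength, +offset_noise_strength], its sign selects the white or black latent per sample, and its absolute value becomes the slerp factor, where the old code always slerped toward white. A toy illustration of the selection logic, with random tensors standing in for the cached VAE latents and a plain lerp standing in for `util.slerp.slerp`:

    import torch

    noise = torch.randn(4, 4, 8, 8)      # per-sample latent noise
    img_white = torch.randn(1, 4, 8, 8)  # stand-in for the cached white latent
    img_black = torch.randn(1, 4, 8, 8)  # stand-in for the cached black latent

    offset_noise_strength = 0.15
    # One signed scalar per sample in [-0.15, 0.15]
    offset_strength = offset_noise_strength * (torch.rand(4, 1, 1, 1) * 2 - 1)

    # Sign picks the target: >= 0 -> white, < 0 -> black
    offset_images = torch.where(
        offset_strength >= 0,
        img_white.expand(noise.shape),
        img_black.expand(noise.shape),
    )

    # Magnitude is the interpolation factor; the real code calls
    # slerp(noise, offset_images, t, zdim=(-1, -2)) here instead of a lerp.
    t = offset_strength.abs().expand(noise.shape)
    noise = (1 - t) * noise + t * offset_images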