summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-05 17:33:06 +0200
committerVolpeon <git@volpeon.ink>2023-04-05 17:33:06 +0200
commit3dd234d8fe86b7d813aec9f43aeb765a88b6a916 (patch)
tree9f8f4eb3cc01fe1578c84b54e11356aed280240b /training
parentNew offset noise test (diff)
downloadtextual-inversion-diff-3dd234d8fe86b7d813aec9f43aeb765a88b6a916.tar.gz
textual-inversion-diff-3dd234d8fe86b7d813aec9f43aeb765a88b6a916.tar.bz2
textual-inversion-diff-3dd234d8fe86b7d813aec9f43aeb765a88b6a916.zip
Improved slerp noise offset: Dedicated black image instead of negative offset
Diffstat (limited to 'training')
-rw-r--r--training/functional.py112
1 files changed, 40 insertions, 72 deletions
diff --git a/training/functional.py b/training/functional.py
index e7f02cb..68071bc 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,7 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
23from models.clip.util import get_extended_embeddings 23from models.clip.util import get_extended_embeddings
24from models.clip.tokenizer import MultiCLIPTokenizer 24from models.clip.tokenizer import MultiCLIPTokenizer
25from training.util import AverageMeter 25from training.util import AverageMeter
26from util.noise import perlin_noise 26from util.slerp import slerp
27 27
28 28
29def const(result=None): 29def const(result=None):
@@ -270,62 +270,16 @@ def snr_weight(noisy_latents, latents, gamma):
270 ) 270 )
271 271
272 272
273def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1): 273def make_solid_image(color: float, shape, vae, dtype, device, generator):
274 """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`. 274 img = torch.tensor(
275 275 [[[[color]]]],
276 `DOT_THR` determines when the vectors are too close to parallel. 276 dtype=dtype,
277 If they are too close, then a regular linear interpolation is used. 277 device=device
278 278 ).expand(1, *shape)
279 `to_cpu` is a flag that optionally computes SLERP on the CPU. 279 img = img * 2 - 1
280 If the input tensors were on a GPU, it moves them back after the computation. 280 img = vae.encode(img).latent_dist.sample(generator=generator)
281 281 img *= vae.config.scaling_factor
282 `zdim` is the feature dimension over which to compute norms and find angles. 282 return img
283 For example: if a sequence of 5 vectors is input with shape [5, 768]
284 Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
285
286 Theory Reference:
287 https://splines.readthedocs.io/en/latest/rotation/slerp.html
288 PyTorch reference:
289 https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
290 Numpy reference:
291 https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
292 """
293
294 # check if we need to move to the cpu
295 if to_cpu:
296 orig_device = v1.device
297 v1, v2 = v1.to('cpu'), v2.to('cpu')
298
299 # take the dot product between normalized vectors
300 v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
301 v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
302 dot = (v1_norm * v2_norm).sum(zdim)
303
304 for _ in range(len(dot.shape), len(v1.shape)):
305 dot = dot[..., None]
306
307 # if the vectors are too close, return a simple linear interpolation
308 if (torch.abs(dot) > DOT_THR).any():
309 res = (1 - t) * v1 + t * v2
310 else:
311 # compute the angle terms we need
312 theta = torch.acos(dot)
313 theta_t = theta * t
314 sin_theta = torch.sin(theta)
315 sin_theta_t = torch.sin(theta_t)
316
317 # compute the sine scaling terms for the vectors
318 s1 = torch.sin(theta - theta_t) / sin_theta
319 s2 = sin_theta_t / sin_theta
320
321 # interpolate the vectors
322 res = s1 * v1 + s2 * v2
323
324 # check if we need to move them back to the original device
325 if to_cpu:
326 res.to(orig_device)
327
328 return res
329 283
330 284
331def loss_step( 285def loss_step(
@@ -361,20 +315,29 @@ def loss_step(
361 ) 315 )
362 316
363 if offset_noise_strength != 0: 317 if offset_noise_strength != 0:
364 cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}" 318 solid_image = partial(
365 319 make_solid_image,
366 if cache_key not in cache: 320 shape=images.shape[1:],
367 img_white = torch.tensor( 321 vae=vae,
368 [[[[1]]]], 322 dtype=latents.dtype,
369 dtype=latents.dtype, 323 device=latents.device,
370 device=latents.device 324 generator=generator
371 ).expand(1, images.shape[1], images.shape[2], images.shape[3]) 325 )
372 img_white = img_white * 2 - 1 326
373 img_white = vae.encode(img_white).latent_dist.sample(generator=generator) 327 white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
374 img_white *= vae.config.scaling_factor 328 black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}"
375 cache[cache_key] = img_white 329
330 if white_cache_key not in cache:
331 img_white = solid_image(1)
332 cache[white_cache_key] = img_white
376 else: 333 else:
377 img_white = cache[cache_key] 334 img_white = cache[white_cache_key]
335
336 if black_cache_key not in cache:
337 img_black = solid_image(0)
338 cache[black_cache_key] = img_black
339 else:
340 img_black = cache[black_cache_key]
378 341
379 offset_strength = torch.rand( 342 offset_strength = torch.rand(
380 (bsz, 1, 1, 1), 343 (bsz, 1, 1, 1),
@@ -384,8 +347,13 @@ def loss_step(
384 generator=generator 347 generator=generator
385 ) 348 )
386 offset_strength = offset_noise_strength * (offset_strength * 2 - 1) 349 offset_strength = offset_noise_strength * (offset_strength * 2 - 1)
387 offset_strength = offset_strength.expand(noise.shape) 350 offset_images = torch.where(
388 noise = slerp(noise, img_white.expand(noise.shape), offset_strength, zdim=(-1, -2)) 351 offset_strength >= 0,
352 img_white.expand(noise.shape),
353 img_black.expand(noise.shape)
354 )
355 offset_strength = offset_strength.abs().expand(noise.shape)
356 noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2))
389 357
390 # Sample a random timestep for each image 358 # Sample a random timestep for each image
391 timesteps = torch.randint( 359 timesteps = torch.randint(