From 3dd234d8fe86b7d813aec9f43aeb765a88b6a916 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 5 Apr 2023 17:33:06 +0200 Subject: Improved slerp noise offset: Dedicated black image instead of negative offset --- util/slerp.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 util/slerp.py (limited to 'util') diff --git a/util/slerp.py b/util/slerp.py new file mode 100644 index 0000000..5ab6bf9 --- /dev/null +++ b/util/slerp.py @@ -0,0 +1,59 @@ +import torch + + +def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1): + """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`. + + `DOT_THR` determines when the vectors are too close to parallel. + If they are too close, then a regular linear interpolation is used. + + `to_cpu` is a flag that optionally computes SLERP on the CPU. + If the input tensors were on a GPU, it moves them back after the computation. + + `zdim` is the feature dimension over which to compute norms and find angles. + For example: if a sequence of 5 vectors is input with shape [5, 768] + Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768. + + Theory Reference: + https://splines.readthedocs.io/en/latest/rotation/slerp.html + PyTorch reference: + https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 + Numpy reference: + https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c + """ + + # check if we need to move to the cpu + if to_cpu: + orig_device = v1.device + v1, v2 = v1.to('cpu'), v2.to('cpu') + + # take the dot product between normalized vectors + v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True) + v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True) + dot = (v1_norm * v2_norm).sum(zdim) + + for _ in range(len(dot.shape), len(v1.shape)): + dot = dot[..., None] + + # if the vectors are too close, return a simple linear interpolation + if (torch.abs(dot) > DOT_THR).any(): + res = (1 - t) * v1 + t * v2 + else: + # compute the angle terms we need + theta = torch.acos(dot) + theta_t = theta * t + sin_theta = torch.sin(theta) + sin_theta_t = torch.sin(theta_t) + + # compute the sine scaling terms for the vectors + s1 = torch.sin(theta - theta_t) / sin_theta + s2 = sin_theta_t / sin_theta + + # interpolate the vectors + res = s1 * v1 + s2 * v2 + + # check if we need to move them back to the original device + if to_cpu: + res.to(orig_device) + + return res -- cgit v1.2.3-70-g09d2