1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
|
import torch
def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
    """Spherically interpolate (SLERP) between torch tensors `v1` and `v2`.

    The angle between each pair of vectors is measured along `zdim`. Pairs
    whose normalized dot product exceeds `DOT_THR` in absolute value are
    treated as (anti-)parallel and are linearly interpolated instead, since
    the SLERP weights divide by the sine of the angle. The choice is made
    independently for every vector pair in a batch.

    Args:
        v1: Start tensor.
        v2: End tensor, broadcastable with `v1`.
        t: Interpolation weight (number or tensor broadcastable with the
            result); `t=0` yields `v1` and `t=1` yields `v2`.
        DOT_THR: Threshold on `|cos(angle)|` above which a pair falls back
            to linear interpolation.
        to_cpu: If True, compute on the CPU and move the result back to the
            device `v1` was on when the function was called.
        zdim: Feature dimension over which norms and angles are computed.
            For example, for a sequence of 5 vectors with shape [5, 768],
            `zdim=1` or `zdim=-1` computes SLERP along the 768 features.

    Returns:
        The interpolated tensor.

    Theory Reference:
    https://splines.readthedocs.io/en/latest/rotation/slerp.html
    PyTorch reference:
    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
    Numpy reference:
    https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    """
    # Optionally run the computation on the CPU, remembering where we were.
    if to_cpu:
        orig_device = v1.device
        v1, v2 = v1.to('cpu'), v2.to('cpu')

    # Cosine of the angle between the vectors. keepdim=True leaves the
    # reduced axis in place so the result broadcasts against v1/v2 for any
    # choice of zdim, not only the last one.
    v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
    v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
    dot = (v1_norm * v2_norm).sum(zdim, keepdim=True)

    # Per-pair mask of vectors that are too close to (anti-)parallel.
    near_parallel = torch.abs(dot) > DOT_THR

    # Linear interpolation, used for the near-parallel pairs.
    lerp = (1 - t) * v1 + t * v2

    # For masked pairs substitute a harmless cosine (0 -> angle of pi/2) so
    # the unused SLERP branch never evaluates acos near +/-1 or divides by a
    # vanishing sine; those entries are discarded by torch.where below.
    safe_dot = torch.where(near_parallel, torch.zeros_like(dot), dot)

    # Angle terms and the sine scaling weights for each vector.
    theta = torch.acos(safe_dot)
    theta_t = theta * t
    sin_theta = torch.sin(theta)
    s1 = torch.sin(theta - theta_t) / sin_theta
    s2 = torch.sin(theta_t) / sin_theta
    spherical = s1 * v1 + s2 * v2

    res = torch.where(near_parallel, lerp, spherical)

    # Tensor.to returns a new tensor, so the result must be reassigned.
    if to_cpu:
        res = res.to(orig_device)
    return res
|