diff options
Diffstat (limited to 'util')
-rw-r--r-- | util/slerp.py | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/util/slerp.py b/util/slerp.py new file mode 100644 index 0000000..5ab6bf9 --- /dev/null +++ b/util/slerp.py | |||
@@ -0,0 +1,59 @@ | |||
1 | import torch | ||
2 | |||
3 | |||
4 | def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1): | ||
5 | """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`. | ||
6 | |||
7 | `DOT_THR` determines when the vectors are too close to parallel. | ||
8 | If they are too close, then a regular linear interpolation is used. | ||
9 | |||
10 | `to_cpu` is a flag that optionally computes SLERP on the CPU. | ||
11 | If the input tensors were on a GPU, it moves them back after the computation. | ||
12 | |||
13 | `zdim` is the feature dimension over which to compute norms and find angles. | ||
14 | For example: if a sequence of 5 vectors is input with shape [5, 768] | ||
15 | Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768. | ||
16 | |||
17 | Theory Reference: | ||
18 | https://splines.readthedocs.io/en/latest/rotation/slerp.html | ||
19 | PyTorch reference: | ||
20 | https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 | ||
21 | Numpy reference: | ||
22 | https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c | ||
23 | """ | ||
24 | |||
25 | # check if we need to move to the cpu | ||
26 | if to_cpu: | ||
27 | orig_device = v1.device | ||
28 | v1, v2 = v1.to('cpu'), v2.to('cpu') | ||
29 | |||
30 | # take the dot product between normalized vectors | ||
31 | v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True) | ||
32 | v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True) | ||
33 | dot = (v1_norm * v2_norm).sum(zdim) | ||
34 | |||
35 | for _ in range(len(dot.shape), len(v1.shape)): | ||
36 | dot = dot[..., None] | ||
37 | |||
38 | # if the vectors are too close, return a simple linear interpolation | ||
39 | if (torch.abs(dot) > DOT_THR).any(): | ||
40 | res = (1 - t) * v1 + t * v2 | ||
41 | else: | ||
42 | # compute the angle terms we need | ||
43 | theta = torch.acos(dot) | ||
44 | theta_t = theta * t | ||
45 | sin_theta = torch.sin(theta) | ||
46 | sin_theta_t = torch.sin(theta_t) | ||
47 | |||
48 | # compute the sine scaling terms for the vectors | ||
49 | s1 = torch.sin(theta - theta_t) / sin_theta | ||
50 | s2 = sin_theta_t / sin_theta | ||
51 | |||
52 | # interpolate the vectors | ||
53 | res = s1 * v1 + s2 * v2 | ||
54 | |||
55 | # check if we need to move them back to the original device | ||
56 | if to_cpu: | ||
57 | res.to(orig_device) | ||
58 | |||
59 | return res | ||