summaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/slerp.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/util/slerp.py b/util/slerp.py
new file mode 100644
index 0000000..5ab6bf9
--- /dev/null
+++ b/util/slerp.py
@@ -0,0 +1,59 @@
1import torch
2
3
4def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
5 """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.
6
7 `DOT_THR` determines when the vectors are too close to parallel.
8 If they are too close, then a regular linear interpolation is used.
9
10 `to_cpu` is a flag that optionally computes SLERP on the CPU.
11 If the input tensors were on a GPU, it moves them back after the computation.
12
13 `zdim` is the feature dimension over which to compute norms and find angles.
14 For example: if a sequence of 5 vectors is input with shape [5, 768]
15 Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
16
17 Theory Reference:
18 https://splines.readthedocs.io/en/latest/rotation/slerp.html
19 PyTorch reference:
20 https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
21 Numpy reference:
22 https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
23 """
24
25 # check if we need to move to the cpu
26 if to_cpu:
27 orig_device = v1.device
28 v1, v2 = v1.to('cpu'), v2.to('cpu')
29
30 # take the dot product between normalized vectors
31 v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
32 v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
33 dot = (v1_norm * v2_norm).sum(zdim)
34
35 for _ in range(len(dot.shape), len(v1.shape)):
36 dot = dot[..., None]
37
38 # if the vectors are too close, return a simple linear interpolation
39 if (torch.abs(dot) > DOT_THR).any():
40 res = (1 - t) * v1 + t * v2
41 else:
42 # compute the angle terms we need
43 theta = torch.acos(dot)
44 theta_t = theta * t
45 sin_theta = torch.sin(theta)
46 sin_theta_t = torch.sin(theta_t)
47
48 # compute the sine scaling terms for the vectors
49 s1 = torch.sin(theta - theta_t) / sin_theta
50 s2 = sin_theta_t / sin_theta
51
52 # interpolate the vectors
53 res = s1 * v1 + s2 * v2
54
55 # check if we need to move them back to the original device
56 if to_cpu:
57 res.to(orig_device)
58
59 return res