import torch


def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
    """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.

    `t` is a Python number or a tensor that broadcasts against `v1`/`v2`
        (and against the per-pair angle term, which has size 1 along `zdim`).

    `DOT_THR` determines when the vectors are too close to parallel.
        If they are too close, then a regular linear interpolation is used.
        The test is applied per vector pair: only pairs whose absolute cosine
        similarity exceeds `DOT_THR` fall back to linear interpolation; the
        remaining pairs are still spherically interpolated.

    `to_cpu` is a flag that optionally computes SLERP on the CPU.
        If the input tensors were on a GPU, the result is moved back to the
        original device of `v1` after the computation.

    `zdim` is the feature dimension over which to compute norms and find angles.
        For example: if a sequence of 5 vectors is input with shape [5, 768]
        Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
        Any dimension may be used, e.g. `zdim = 0` when vectors are columns.

    Returns a tensor with the broadcast shape of `v1`, `v2` and `t`.

    Note: zero-length input vectors are not special-cased; normalizing them
    divides by zero and yields NaNs.

    Theory Reference:
    https://splines.readthedocs.io/en/latest/rotation/slerp.html
    PyTorch reference:
    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
    Numpy reference:
    https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    """
    # optionally run the computation on the CPU, remembering where the result goes
    if to_cpu:
        orig_device = v1.device
        v1, v2 = v1.to('cpu'), v2.to('cpu')
        # a tensor-valued `t` must live on the same device as the vectors
        if torch.is_tensor(t):
            t = t.to('cpu')

    # cosine of the angle between the vectors; keepdim=True leaves a size-1 axis
    # at `zdim`, so `dot` broadcasts correctly against v1/v2 for any `zdim`
    v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
    v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
    dot = (v1_norm * v2_norm).sum(zdim, keepdim=True)

    # per-pair mask of (anti)parallel vectors, where SLERP is numerically unstable
    near_parallel = torch.abs(dot) > DOT_THR

    # Masked pairs get dot = 0 (angle pi/2, sin = 1) so acos and the division by
    # sin(theta) stay finite in both the forward and backward passes; their
    # values are replaced by the linear result below. The clamp only matters
    # when DOT_THR >= 1, where rounding could push |dot| past 1 and make acos NaN.
    safe_dot = torch.where(near_parallel, torch.zeros_like(dot), dot).clamp(-1.0, 1.0)

    # compute the angle terms we need
    theta = torch.acos(safe_dot)
    theta_t = theta * t
    sin_theta = torch.sin(theta)

    # compute the sine scaling terms for the vectors
    s1 = torch.sin(theta - theta_t) / sin_theta
    s2 = torch.sin(theta_t) / sin_theta

    slerped = s1 * v1 + s2 * v2
    lerped = (1 - t) * v1 + t * v2

    # choose linear interpolation only for the near-parallel pairs
    res = torch.where(near_parallel, lerped, slerped)

    # Tensor.to returns a new tensor (it is not in-place), so reassign it
    if to_cpu:
        res = res.to(orig_device)
    return res