From 3dd234d8fe86b7d813aec9f43aeb765a88b6a916 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Wed, 5 Apr 2023 17:33:06 +0200
Subject: Improved slerp noise offset: Dedicated black image instead of
 negative offset

---
 util/slerp.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)
 create mode 100644 util/slerp.py

(limited to 'util')

diff --git a/util/slerp.py b/util/slerp.py
new file mode 100644
index 0000000..5ab6bf9
--- /dev/null
+++ b/util/slerp.py
@@ -0,0 +1,59 @@
+import torch
+
+
+def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
+    """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.
+
+    `DOT_THR` determines when the vectors are too close to parallel.
+        If they are too close, then a regular linear interpolation is used.
+
+    `to_cpu` is a flag that optionally computes SLERP on the CPU.
+        If the input tensors were on a GPU, it moves them back after the computation.  
+
+    `zdim` is the feature dimension over which to compute norms and find angles.
+        For example: if a sequence of 5 vectors is input with shape [5, 768]
+        Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
+
+    Theory Reference:
+    https://splines.readthedocs.io/en/latest/rotation/slerp.html
+    PyTorch reference:
+    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
+    Numpy reference: 
+    https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
+    """
+
+    # check if we need to move to the cpu
+    if to_cpu:
+        orig_device = v1.device
+        v1, v2 = v1.to('cpu'), v2.to('cpu')
+
+    # take the dot product between normalized vectors
+    v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
+    v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
+    dot = (v1_norm * v2_norm).sum(zdim)
+
+    for _ in range(len(dot.shape), len(v1.shape)):
+        dot = dot[..., None]
+
+    # if the vectors are too close, return a simple linear interpolation
+    if (torch.abs(dot) > DOT_THR).any():
+        res = (1 - t) * v1 + t * v2
+    else:
+        # compute the angle terms we need
+        theta = torch.acos(dot)
+        theta_t = theta * t
+        sin_theta = torch.sin(theta)
+        sin_theta_t = torch.sin(theta_t)
+
+        # compute the sine scaling terms for the vectors
+        s1 = torch.sin(theta - theta_t) / sin_theta
+        s2 = sin_theta_t / sin_theta
+
+        # interpolate the vectors
+        res = s1 * v1 + s2 * v2
+
+    # check if we need to move them back to the original device
+    if to_cpu:
+        res.to(orig_device)
+
+    return res
-- 
cgit v1.2.3-70-g09d2