diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-05 17:33:06 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-05 17:33:06 +0200 |
| commit | 3dd234d8fe86b7d813aec9f43aeb765a88b6a916 (patch) | |
| tree | 9f8f4eb3cc01fe1578c84b54e11356aed280240b /util | |
| parent | New offset noise test (diff) | |
| download | textual-inversion-diff-3dd234d8fe86b7d813aec9f43aeb765a88b6a916.tar.gz textual-inversion-diff-3dd234d8fe86b7d813aec9f43aeb765a88b6a916.tar.bz2 textual-inversion-diff-3dd234d8fe86b7d813aec9f43aeb765a88b6a916.zip | |
Improved slerp noise offset: Dedicated black image instead of negative offset
Diffstat (limited to 'util')
| -rw-r--r-- | util/slerp.py | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/util/slerp.py b/util/slerp.py new file mode 100644 index 0000000..5ab6bf9 --- /dev/null +++ b/util/slerp.py | |||
| @@ -0,0 +1,59 @@ | |||
| 1 | import torch | ||
| 2 | |||
| 3 | |||
| 4 | def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1): | ||
| 5 | """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`. | ||
| 6 | |||
| 7 | `DOT_THR` determines when the vectors are too close to parallel. | ||
| 8 | If they are too close, then a regular linear interpolation is used. | ||
| 9 | |||
| 10 | `to_cpu` is a flag that optionally computes SLERP on the CPU. | ||
| 11 | If the input tensors were on a GPU, it moves them back after the computation. | ||
| 12 | |||
| 13 | `zdim` is the feature dimension over which to compute norms and find angles. | ||
| 14 | For example: if a sequence of 5 vectors is input with shape [5, 768] | ||
| 15 | Then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768. | ||
| 16 | |||
| 17 | Theory Reference: | ||
| 18 | https://splines.readthedocs.io/en/latest/rotation/slerp.html | ||
| 19 | PyTorch reference: | ||
| 20 | https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 | ||
| 21 | Numpy reference: | ||
| 22 | https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c | ||
| 23 | """ | ||
| 24 | |||
| 25 | # check if we need to move to the cpu | ||
| 26 | if to_cpu: | ||
| 27 | orig_device = v1.device | ||
| 28 | v1, v2 = v1.to('cpu'), v2.to('cpu') | ||
| 29 | |||
| 30 | # take the dot product between normalized vectors | ||
| 31 | v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True) | ||
| 32 | v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True) | ||
| 33 | dot = (v1_norm * v2_norm).sum(zdim) | ||
| 34 | |||
| 35 | for _ in range(len(dot.shape), len(v1.shape)): | ||
| 36 | dot = dot[..., None] | ||
| 37 | |||
| 38 | # if the vectors are too close, return a simple linear interpolation | ||
| 39 | if (torch.abs(dot) > DOT_THR).any(): | ||
| 40 | res = (1 - t) * v1 + t * v2 | ||
| 41 | else: | ||
| 42 | # compute the angle terms we need | ||
| 43 | theta = torch.acos(dot) | ||
| 44 | theta_t = theta * t | ||
| 45 | sin_theta = torch.sin(theta) | ||
| 46 | sin_theta_t = torch.sin(theta_t) | ||
| 47 | |||
| 48 | # compute the sine scaling terms for the vectors | ||
| 49 | s1 = torch.sin(theta - theta_t) / sin_theta | ||
| 50 | s2 = sin_theta_t / sin_theta | ||
| 51 | |||
| 52 | # interpolate the vectors | ||
| 53 | res = s1 * v1 + s2 * v2 | ||
| 54 | |||
| 55 | # check if we need to move them back to the original device | ||
| 56 | if to_cpu: | ||
| 57 | res.to(orig_device) | ||
| 58 | |||
| 59 | return res | ||
