diff options
author | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-05-06 16:25:36 +0200 |
commit | 7b04d813739c0b5595295dffdc86cc41108db2d3 (patch) | |
tree | 8958b612f5d3d665866770ad553e1004aa4b6fb8 /training/sampler.py | |
parent | Update (diff) | |
download | textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.gz textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.bz2 textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.zip |
Update
Diffstat (limited to 'training/sampler.py')
-rw-r--r-- | training/sampler.py | 154 |
1 files changed, 154 insertions, 0 deletions
diff --git a/training/sampler.py b/training/sampler.py new file mode 100644 index 0000000..8afe255 --- /dev/null +++ b/training/sampler.py | |||
@@ -0,0 +1,154 @@ | |||
1 | from abc import ABC, abstractmethod | ||
2 | |||
3 | import numpy as np | ||
4 | import torch | ||
5 | import torch.distributed as dist | ||
6 | |||
7 | |||
8 | def create_named_schedule_sampler(name, num_timesteps): | ||
9 | """ | ||
10 | Create a ScheduleSampler from a library of pre-defined samplers. | ||
11 | |||
12 | :param name: the name of the sampler. | ||
13 | :param diffusion: the diffusion object to sample for. | ||
14 | """ | ||
15 | if name == "uniform": | ||
16 | return UniformSampler(num_timesteps) | ||
17 | elif name == "loss-second-moment": | ||
18 | return LossSecondMomentResampler(num_timesteps) | ||
19 | else: | ||
20 | raise NotImplementedError(f"unknown schedule sampler: {name}") | ||
21 | |||
22 | |||
23 | class ScheduleSampler(ABC): | ||
24 | """ | ||
25 | A distribution over timesteps in the diffusion process, intended to reduce | ||
26 | variance of the objective. | ||
27 | |||
28 | By default, samplers perform unbiased importance sampling, in which the | ||
29 | objective's mean is unchanged. | ||
30 | However, subclasses may override sample() to change how the resampled | ||
31 | terms are reweighted, allowing for actual changes in the objective. | ||
32 | """ | ||
33 | |||
34 | @abstractmethod | ||
35 | def weights(self): | ||
36 | """ | ||
37 | Get a numpy array of weights, one per diffusion step. | ||
38 | |||
39 | The weights needn't be normalized, but must be positive. | ||
40 | """ | ||
41 | |||
42 | def sample(self, batch_size, device): | ||
43 | """ | ||
44 | Importance-sample timesteps for a batch. | ||
45 | |||
46 | :param batch_size: the number of timesteps. | ||
47 | :param device: the torch device to save to. | ||
48 | :return: a tuple (timesteps, weights): | ||
49 | - timesteps: a tensor of timestep indices. | ||
50 | - weights: a tensor of weights to scale the resulting losses. | ||
51 | """ | ||
52 | w = self.weights() | ||
53 | p = w / np.sum(w) | ||
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) | ||
55 | indices = torch.from_numpy(indices_np).long().to(device) | ||
56 | weights_np = 1 / (len(p) * p[indices_np]) | ||
57 | weights = torch.from_numpy(weights_np).float().to(device) | ||
58 | return indices, weights | ||
59 | |||
60 | |||
61 | class UniformSampler(ScheduleSampler): | ||
62 | def __init__(self, num_timesteps): | ||
63 | self.num_timesteps = num_timesteps | ||
64 | self._weights = np.ones([num_timesteps]) | ||
65 | |||
66 | def weights(self): | ||
67 | return self._weights | ||
68 | |||
69 | |||
70 | class LossAwareSampler(ScheduleSampler): | ||
71 | def update_with_local_losses(self, local_ts, local_losses): | ||
72 | """ | ||
73 | Update the reweighting using losses from a model. | ||
74 | |||
75 | Call this method from each rank with a batch of timesteps and the | ||
76 | corresponding losses for each of those timesteps. | ||
77 | This method will perform synchronization to make sure all of the ranks | ||
78 | maintain the exact same reweighting. | ||
79 | |||
80 | :param local_ts: an integer Tensor of timesteps. | ||
81 | :param local_losses: a 1D Tensor of losses. | ||
82 | """ | ||
83 | batch_sizes = [ | ||
84 | torch.tensor([0], dtype=torch.int32, device=local_ts.device) | ||
85 | for _ in range(dist.get_world_size()) | ||
86 | ] | ||
87 | dist.all_gather( | ||
88 | batch_sizes, | ||
89 | torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device), | ||
90 | ) | ||
91 | |||
92 | # Pad all_gather batches to be the maximum batch size. | ||
93 | batch_sizes = [x.item() for x in batch_sizes] | ||
94 | max_bs = max(batch_sizes) | ||
95 | |||
96 | timestep_batches = [torch.zeros(max_bs).to(local_ts) for bs in batch_sizes] | ||
97 | loss_batches = [torch.zeros(max_bs).to(local_losses) for bs in batch_sizes] | ||
98 | dist.all_gather(timestep_batches, local_ts) | ||
99 | dist.all_gather(loss_batches, local_losses) | ||
100 | timesteps = [ | ||
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] | ||
102 | ] | ||
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] | ||
104 | self.update_with_all_losses(timesteps, losses) | ||
105 | |||
106 | @abstractmethod | ||
107 | def update_with_all_losses(self, ts, losses): | ||
108 | """ | ||
109 | Update the reweighting using losses from a model. | ||
110 | |||
111 | Sub-classes should override this method to update the reweighting | ||
112 | using losses from the model. | ||
113 | |||
114 | This method directly updates the reweighting without synchronizing | ||
115 | between workers. It is called by update_with_local_losses from all | ||
116 | ranks with identical arguments. Thus, it should have deterministic | ||
117 | behavior to maintain state across workers. | ||
118 | |||
119 | :param ts: a list of int timesteps. | ||
120 | :param losses: a list of float losses, one per timestep. | ||
121 | """ | ||
122 | |||
123 | |||
124 | class LossSecondMomentResampler(LossAwareSampler): | ||
125 | def __init__(self, num_timesteps, history_per_term=10, uniform_prob=0.001): | ||
126 | self.num_timesteps = num_timesteps | ||
127 | self.history_per_term = history_per_term | ||
128 | self.uniform_prob = uniform_prob | ||
129 | self._loss_history = np.zeros( | ||
130 | [self.num_timesteps, history_per_term], dtype=np.float64 | ||
131 | ) | ||
132 | self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int) | ||
133 | |||
134 | def weights(self): | ||
135 | if not self._warmed_up(): | ||
136 | return np.ones([self.num_timesteps], dtype=np.float64) | ||
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) | ||
138 | weights /= np.sum(weights) | ||
139 | weights *= 1 - self.uniform_prob | ||
140 | weights += self.uniform_prob / len(weights) | ||
141 | return weights | ||
142 | |||
143 | def update_with_all_losses(self, ts, losses): | ||
144 | for t, loss in zip(ts, losses): | ||
145 | if self._loss_counts[t] == self.history_per_term: | ||
146 | # Shift out the oldest loss term. | ||
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] | ||
148 | self._loss_history[t, -1] = loss | ||
149 | else: | ||
150 | self._loss_history[t, self._loss_counts[t]] = loss | ||
151 | self._loss_counts[t] += 1 | ||
152 | |||
153 | def _warmed_up(self): | ||
154 | return (self._loss_counts == self.history_per_term).all() | ||