Diffstat (limited to 'training/sampler.py')
-rw-r--r--  training/sampler.py  154
1 files changed, 154 insertions, 0 deletions
diff --git a/training/sampler.py b/training/sampler.py
new file mode 100644
index 0000000..8afe255
--- /dev/null
+++ b/training/sampler.py
@@ -0,0 +1,154 @@
from abc import ABC, abstractmethod

import numpy as np
import torch
import torch.distributed as dist


def create_named_schedule_sampler(name, num_timesteps):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param num_timesteps: the number of diffusion timesteps to sample over.
    """
    if name == "uniform":
        return UniformSampler(num_timesteps)
    elif name == "loss-second-moment":
        return LossSecondMomentResampler(num_timesteps)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps to sample.
        :param device: the torch device to place the sampled tensors on.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = torch.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = torch.from_numpy(weights_np).float().to(device)
        return indices, weights


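# Illustrative note (added commentary, not part of the original module):
# sample() draws timestep t with probability p[t] and returns the weight
# 1 / (T * p[t]), so for any per-step loss L[t] the weighted loss is an
# unbiased estimate of the uniform average:
#   E_{t ~ p}[ weight(t) * L[t] ] = sum_t p[t] * L[t] / (T * p[t]) = mean_t L[t].
# For example, with T = 2 and p = [0.8, 0.2], the weights are [0.625, 2.5] and
# 0.8 * 0.625 * L[0] + 0.2 * 2.5 * L[1] = 0.5 * (L[0] + L[1]).
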
class UniformSampler(ScheduleSampler):
    def __init__(self, num_timesteps):
        self.num_timesteps = num_timesteps
        self._weights = np.ones([num_timesteps])

    def weights(self):
        return self._weights


class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        batch_sizes = [
            torch.tensor([0], dtype=torch.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            torch.tensor([len(local_ts)], dtype=torch.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        timestep_batches = [torch.zeros(max_bs).to(local_ts) for _ in batch_sizes]
        loss_batches = [torch.zeros(max_bs).to(local_losses) for _ in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """


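# Illustrative note (added commentary, not part of the original module):
# update_with_local_losses() pads every rank's batch to the largest gathered
# batch size so that all_gather sees equally shaped tensors, then trims the
# padding away before recording losses. For example, with two ranks where
# rank 0 holds ts=[3, 7] and rank 1 holds ts=[5], batch_sizes == [2, 1] and
# max_bs == 2; rank 1's gathered tensors are zero-padded to length 2, the
# y[:bs] slices drop the pad entry, and every rank ends up calling
# update_with_all_losses() with ts == [3, 7, 5].
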
class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, num_timesteps, history_per_term=10, uniform_prob=0.001):
        self.num_timesteps = num_timesteps
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [self.num_timesteps, history_per_term], dtype=np.float64
        )
        self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int64)

    def weights(self):
        if not self._warmed_up():
            return np.ones([self.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()
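

if __name__ == "__main__":
    # Minimal single-process usage sketch (added for illustration; the numbers
    # and fake losses below are hypothetical, not part of the training code).
    # In real distributed training, per-rank losses would be reported through
    # update_with_local_losses(), which synchronizes across ranks; here we call
    # update_with_all_losses() directly to avoid needing torch.distributed.
    num_timesteps = 100
    sampler = create_named_schedule_sampler("loss-second-moment", num_timesteps)

    for _ in range(50):
        ts, _weights = sampler.sample(batch_size=256, device="cpu")
        # Stand-in for per-timestep diffusion losses: pretend later timesteps
        # are slightly noisier, plus some random jitter.
        fake_losses = 0.1 + ts.float() / num_timesteps + 0.01 * torch.rand(ts.shape)
        sampler.update_with_all_losses(ts.tolist(), fake_losses.tolist())

    # Once every timestep has history_per_term recorded losses, the weights
    # move away from uniform toward the per-timestep loss second moments.
    w = sampler.weights()
    print("warmed up:", bool(sampler._warmed_up()))
    print("weight range:", w.min(), "to", w.max())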