author    Volpeon <git@volpeon.ink>  2023-03-21 13:46:36 +0100
committer Volpeon <git@volpeon.ink>  2023-03-21 13:46:36 +0100
commit    f5e0e98f6df9260a93fb650a0b97c85eb87b0fd3 (patch)
tree      0d061f5fd8950d7ca7e0198731ee58980859dd18 /training/functional.py
parent    Restore min SNR (diff)
Fixed SNR weighting, re-enabled xformers
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  35

1 file changed, 25 insertions(+), 10 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 8dc2b9f..43ee356 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -251,6 +251,25 @@ def add_placeholder_tokens(
     return placeholder_token_ids, initializer_token_ids


+def snr_weight(noisy_latents, latents, gamma):
+    if gamma:
+        sigma = torch.sub(noisy_latents, latents)
+        zeros = torch.zeros_like(sigma)
+        alpha_mean_sq = F.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        snr = torch.div(alpha_mean_sq, sigma_mean_sq)
+        gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
+        return snr_weight
+
+    return torch.tensor(
+        [1],
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+    )
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: SchedulerMixin,
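Note: the new snr_weight above estimates the signal-to-noise ratio empirically, per sample, as mean(latents**2) / mean((noisy_latents - latents)**2) over the channel and spatial dimensions, then applies the Min-SNR rule weight = min(gamma / SNR, 1). A minimal standalone sketch of the same computation (tensor shapes and the gamma value here are illustrative, not taken from the repository):

    import torch

    # Dummy latents: batch of 2, 4 channels, 8x8 spatial (illustrative shapes).
    latents = torch.randn(2, 4, 8, 8)
    noisy_latents = latents + 0.5 * torch.randn_like(latents)
    gamma = 5.0

    # Empirical per-sample SNR over channel/spatial dims; F.mse_loss against
    # zeros in the patch is equivalent to this direct mean of squares.
    sigma = noisy_latents - latents
    snr = latents.float().pow(2).mean([1, 2, 3]) / sigma.float().pow(2).mean([1, 2, 3])

    # Min-SNR rule: samples whose SNR exceeds gamma are down-weighted to
    # gamma/SNR; all others keep weight 1.
    weight = torch.minimum(gamma / snr, torch.ones_like(snr))
    print(weight)  # one weight per sample, each in (0, 1]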
@@ -308,21 +327,13 @@ def loss_step(
     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

     # Get the target for loss depending on the prediction type
-    alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
-    snr = alpha_t / (1 - alpha_t)
-    min_snr = snr.clamp(max=min_snr_gamma)
-
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
-        loss_weight = min_snr / snr
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        loss_weight = min_snr / (snr + 1)
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

-    loss_weight = loss_weight[..., None, None, None]
-
     if with_prior_preservation:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
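For contrast, the lines deleted above derived the SNR analytically from the scheduler's cumulative alphas, SNR(t) = alpha_bar_t / (1 - alpha_bar_t), with a prediction-type-specific weight. A sketch of that earlier form, assuming diffusers' DDPMScheduler (scheduler configuration, timesteps, and gamma are illustrative):

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    timesteps = torch.tensor([10, 500, 990])
    gamma = 5.0

    # Analytic SNR of the forward process at each sampled timestep.
    alpha_t = scheduler.alphas_cumprod[timesteps].float()
    snr = alpha_t / (1 - alpha_t)

    min_snr = snr.clamp(max=gamma)
    eps_weight = min_snr / snr       # weight used for epsilon prediction
    v_weight = min_snr / (snr + 1)   # weight used for v prediction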
@@ -339,7 +350,11 @@ def loss_step(
     else:
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

-    loss = (loss_weight * loss).mean([1, 2, 3]).mean()
+    loss = loss.mean([1, 2, 3])
+
+    loss_weight = snr_weight(noisy_latents, latents, min_snr_gamma)
+    loss = (loss_weight * loss).mean()
+
     acc = (model_pred == target).float().mean()

     return loss, acc, bsz
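Since the weight is a single scalar per sample, applying it after the per-sample mean (new code) yields the same loss as broadcasting it over pixels before reducing (old code). A quick equivalence check with dummy values:

    import torch

    loss = torch.rand(2, 4, 8, 8)   # per-pixel MSE with reduction="none"
    w = torch.tensor([0.3, 1.0])    # one Min-SNR weight per sample

    # Old order: broadcast the weight over pixels, then reduce everything.
    old = (w[:, None, None, None] * loss).mean([1, 2, 3]).mean()
    # New order: reduce to a per-sample loss first, then weight.
    new = (w * loss.mean([1, 2, 3])).mean()
    assert torch.allclose(old, new)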
@@ -412,7 +427,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0 and epoch != 0:
+                if epoch % sample_frequency == 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()

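The last hunk drops the "and epoch != 0" guard, so the sampling branch now also fires at epoch 0, before any training has happened. A tiny illustration of the changed condition (the sample_frequency value is made up):

    sample_frequency = 5
    epochs = range(12)
    old = [e for e in epochs if e % sample_frequency == 0 and e != 0]  # [5, 10]
    new = [e for e in epochs if e % sample_frequency == 0]             # [0, 5, 10]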