summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py108
1 files changed, 41 insertions, 67 deletions
diff --git a/training/functional.py b/training/functional.py
index cc079ef..8917eb7 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -94,6 +94,8 @@ def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
94 pretrained_model_name_or_path, subfolder="scheduler" 94 pretrained_model_name_or_path, subfolder="scheduler"
95 ) 95 )
96 96
97 prepare_scheduler_for_custom_training(noise_scheduler, "cuda")
98
97 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler 99 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
98 100
99 101
@@ -273,68 +275,39 @@ def add_placeholder_tokens(
273 return placeholder_token_ids, initializer_token_ids 275 return placeholder_token_ids, initializer_token_ids
274 276
275 277
276def compute_snr(timesteps, noise_scheduler): 278def prepare_scheduler_for_custom_training(noise_scheduler, device):
277 """ 279 if hasattr(noise_scheduler, "all_snr"):
278 Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 280 return
279 """ 281
280 alphas_cumprod = noise_scheduler.alphas_cumprod
281 sqrt_alphas_cumprod = alphas_cumprod**0.5
282 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
283
284 # Expand the tensors.
285 # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
286 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
287 timesteps
288 ].float()
289 while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
290 sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
291 alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
292
293 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
294 device=timesteps.device
295 )[timesteps].float()
296 while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
297 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
298 sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
299
300 # Compute SNR.
301 snr = (alpha / sigma) ** 2
302 return snr
303
304
305def get_original(
306 noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor
307):
308 alphas_cumprod = noise_scheduler.alphas_cumprod 282 alphas_cumprod = noise_scheduler.alphas_cumprod
309 sqrt_alphas_cumprod = alphas_cumprod**0.5 283 sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
310 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 284 sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
311 285 alpha = sqrt_alphas_cumprod
312 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ 286 sigma = sqrt_one_minus_alphas_cumprod
313 timesteps 287 all_snr = (alpha / sigma) ** 2
314 ].float() 288
315 while len(sqrt_alphas_cumprod.shape) < len(sample.shape): 289 noise_scheduler.all_snr = all_snr.to(device)
316 sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 290
317 alpha = sqrt_alphas_cumprod.expand(sample.shape) 291
318 292def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
319 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( 293 snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
320 device=timesteps.device 294 gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
321 )[timesteps].float() 295 snr_weight = torch.minimum(
322 while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): 296 gamma_over_snr, torch.ones_like(gamma_over_snr)
323 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] 297 ).float() # from paper
324 sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) 298 loss = loss * snr_weight
299 return loss
325 300
326 if noise_scheduler.config.prediction_type == "epsilon": 301
327 pred_original_sample = (sample - sigma * model_output) / alpha 302def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
328 elif noise_scheduler.config.prediction_type == "sample": 303 snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
329 pred_original_sample = model_output 304 snr_t = torch.minimum(
330 elif noise_scheduler.config.prediction_type == "v_prediction": 305 snr_t, torch.ones_like(snr_t) * 1000
331 pred_original_sample = alpha * sample - sigma * model_output 306 ) # if timestep is 0, snr_t is inf, so limit it to 1000
332 else: 307 scale = snr_t / (snr_t + 1)
333 raise ValueError( 308
334 f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" 309 loss = loss * scale
335 " `v_prediction` for the DDPMScheduler." 310 return loss
336 )
337 return pred_original_sample
338 311
339 312
340def loss_step( 313def loss_step(
@@ -347,6 +320,7 @@ def loss_step(
347 seed: int, 320 seed: int,
348 input_pertubation: float, 321 input_pertubation: float,
349 min_snr_gamma: int, 322 min_snr_gamma: int,
323 scale_v_pred_loss_like_noise_pred: bool,
350 step: int, 324 step: int,
351 batch: dict[str, Any], 325 batch: dict[str, Any],
352 cache: dict[Any, Any], 326 cache: dict[Any, Any],
@@ -433,14 +407,12 @@ def loss_step(
433 loss = loss.mean([1, 2, 3]) 407 loss = loss.mean([1, 2, 3])
434 408
435 if min_snr_gamma != 0: 409 if min_snr_gamma != 0:
436 snr = compute_snr(timesteps, noise_scheduler) 410 loss = apply_snr_weight(loss, timesteps, noise_scheduler, min_snr_gamma)
437 mse_loss_weights = ( 411
438 torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min( 412 if scale_v_pred_loss_like_noise_pred:
439 dim=1 413 loss = scale_v_prediction_loss_like_noise_prediction(
440 )[0] 414 loss, timesteps, noise_scheduler
441 / snr
442 ) 415 )
443 loss = loss * mse_loss_weights
444 416
445 if isinstance(schedule_sampler, LossAwareSampler): 417 if isinstance(schedule_sampler, LossAwareSampler):
446 schedule_sampler.update_with_all_losses(timesteps, loss.detach()) 418 schedule_sampler.update_with_all_losses(timesteps, loss.detach())
@@ -726,6 +698,7 @@ def train(
726 input_pertubation: float = 0.1, 698 input_pertubation: float = 0.1,
727 schedule_sampler: Optional[ScheduleSampler] = None, 699 schedule_sampler: Optional[ScheduleSampler] = None,
728 min_snr_gamma: int = 5, 700 min_snr_gamma: int = 5,
701 scale_v_pred_loss_like_noise_pred: bool = True,
729 avg_loss: AverageMeter = AverageMeter(), 702 avg_loss: AverageMeter = AverageMeter(),
730 avg_acc: AverageMeter = AverageMeter(), 703 avg_acc: AverageMeter = AverageMeter(),
731 avg_loss_val: AverageMeter = AverageMeter(), 704 avg_loss_val: AverageMeter = AverageMeter(),
@@ -785,6 +758,7 @@ def train(
785 seed, 758 seed,
786 input_pertubation, 759 input_pertubation,
787 min_snr_gamma, 760 min_snr_gamma,
761 scale_v_pred_loss_like_noise_pred,
788 ) 762 )
789 763
790 train_loop( 764 train_loop(