diff options
author | Volpeon <git@volpeon.ink> | 2023-04-11 22:36:05 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-11 22:36:05 +0200 |
commit | f39286fa5c5840b67dadf8e85f5f5d7ff1414aab (patch) | |
tree | b56956444af8439404bb8eb9b82508ac5e2194b9 /training | |
parent | Store sample images in Tensorboard as well (diff) | |
download | textual-inversion-diff-f39286fa5c5840b67dadf8e85f5f5d7ff1414aab.tar.gz textual-inversion-diff-f39286fa5c5840b67dadf8e85f5f5d7ff1414aab.tar.bz2 textual-inversion-diff-f39286fa5c5840b67dadf8e85f5f5d7ff1414aab.zip |
Experimental convnext discriminator support
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 83 |
1 files changed, 59 insertions, 24 deletions
diff --git a/training/functional.py b/training/functional.py index 2f7f837..be39776 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -23,6 +23,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
23 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 23 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
24 | from models.clip.util import get_extended_embeddings | 24 | from models.clip.util import get_extended_embeddings |
25 | from models.clip.tokenizer import MultiCLIPTokenizer | 25 | from models.clip.tokenizer import MultiCLIPTokenizer |
26 | from models.convnext.discriminator import ConvNeXtDiscriminator | ||
26 | from training.util import AverageMeter | 27 | from training.util import AverageMeter |
27 | from util.slerp import slerp | 28 | from util.slerp import slerp |
28 | 29 | ||
@@ -160,7 +161,8 @@ def save_samples( | |||
160 | 161 | ||
161 | for tracker in accelerator.trackers: | 162 | for tracker in accelerator.trackers: |
162 | if tracker.name == "tensorboard": | 163 | if tracker.name == "tensorboard": |
163 | tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") | 164 | # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") |
165 | pass | ||
164 | 166 | ||
165 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) | 167 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) |
166 | image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] | 168 | image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] |
@@ -275,16 +277,38 @@ def compute_snr(timesteps, noise_scheduler): | |||
275 | return snr | 277 | return snr |
276 | 278 | ||
277 | 279 | ||
278 | def make_solid_image(color: float, shape, vae, dtype, device, generator): | 280 | def get_original( |
279 | img = torch.tensor( | 281 | noise_scheduler, |
280 | [[[[color]]]], | 282 | model_output, |
281 | dtype=dtype, | 283 | sample: torch.FloatTensor, |
282 | device=device | 284 | timesteps: torch.IntTensor |
283 | ).expand(1, *shape) | 285 | ): |
284 | img = img * 2 - 1 | 286 | alphas_cumprod = noise_scheduler.alphas_cumprod |
285 | img = vae.encode(img).latent_dist.sample(generator=generator) | 287 | sqrt_alphas_cumprod = alphas_cumprod**0.5 |
286 | img *= vae.config.scaling_factor | 288 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 |
287 | return img | 289 | |
290 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | ||
291 | while len(sqrt_alphas_cumprod.shape) < len(sample.shape): | ||
292 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | ||
293 | alpha = sqrt_alphas_cumprod.expand(sample.shape) | ||
294 | |||
295 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | ||
296 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): | ||
297 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | ||
298 | sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) | ||
299 | |||
300 | if noise_scheduler.config.prediction_type == "epsilon": | ||
301 | pred_original_sample = (sample - sigma * model_output) / alpha | ||
302 | elif noise_scheduler.config.prediction_type == "sample": | ||
303 | pred_original_sample = model_output | ||
304 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
305 | pred_original_sample = alpha * sample - sigma * model_output | ||
306 | else: | ||
307 | raise ValueError( | ||
308 | f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" | ||
309 | " `v_prediction` for the DDPMScheduler." | ||
310 | ) | ||
311 | return pred_original_sample | ||
288 | 312 | ||
289 | 313 | ||
290 | def loss_step( | 314 | def loss_step( |
@@ -296,6 +320,7 @@ def loss_step( | |||
296 | prior_loss_weight: float, | 320 | prior_loss_weight: float, |
297 | seed: int, | 321 | seed: int, |
298 | offset_noise_strength: float, | 322 | offset_noise_strength: float, |
323 | disc: Optional[ConvNeXtDiscriminator], | ||
299 | min_snr_gamma: int, | 324 | min_snr_gamma: int, |
300 | step: int, | 325 | step: int, |
301 | batch: dict[str, Any], | 326 | batch: dict[str, Any], |
@@ -373,23 +398,31 @@ def loss_step( | |||
373 | else: | 398 | else: |
374 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 399 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
375 | 400 | ||
376 | if guidance_scale == 0 and prior_loss_weight != 0: | 401 | if disc is None: |
377 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 402 | if guidance_scale == 0 and prior_loss_weight != 0: |
378 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 403 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
379 | target, target_prior = torch.chunk(target, 2, dim=0) | 404 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
405 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
380 | 406 | ||
381 | # Compute instance loss | 407 | # Compute instance loss |
382 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 408 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
383 | 409 | ||
384 | # Compute prior loss | 410 | # Compute prior loss |
385 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") | 411 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") |
386 | 412 | ||
387 | # Add the prior loss to the instance loss. | 413 | # Add the prior loss to the instance loss. |
388 | loss = loss + prior_loss_weight * prior_loss | 414 | loss = loss + prior_loss_weight * prior_loss |
389 | else: | 415 | else: |
390 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 416 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
391 | 417 | ||
392 | loss = loss.mean([1, 2, 3]) | 418 | loss = loss.mean([1, 2, 3]) |
419 | else: | ||
420 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) | ||
421 | rec_latent /= vae.config.scaling_factor | ||
422 | rec_latent = rec_latent.to(dtype=vae.dtype) | ||
423 | rec = vae.decode(rec_latent).sample | ||
424 | loss = 1 - disc.get_score(rec) | ||
425 | del rec_latent, rec | ||
393 | 426 | ||
394 | if min_snr_gamma != 0: | 427 | if min_snr_gamma != 0: |
395 | snr = compute_snr(timesteps, noise_scheduler) | 428 | snr = compute_snr(timesteps, noise_scheduler) |
@@ -645,6 +678,7 @@ def train( | |||
645 | guidance_scale: float = 0.0, | 678 | guidance_scale: float = 0.0, |
646 | prior_loss_weight: float = 1.0, | 679 | prior_loss_weight: float = 1.0, |
647 | offset_noise_strength: float = 0.15, | 680 | offset_noise_strength: float = 0.15, |
681 | disc: Optional[ConvNeXtDiscriminator] = None, | ||
648 | min_snr_gamma: int = 5, | 682 | min_snr_gamma: int = 5, |
649 | **kwargs, | 683 | **kwargs, |
650 | ): | 684 | ): |
@@ -676,6 +710,7 @@ def train( | |||
676 | prior_loss_weight, | 710 | prior_loss_weight, |
677 | seed, | 711 | seed, |
678 | offset_noise_strength, | 712 | offset_noise_strength, |
713 | disc, | ||
679 | min_snr_gamma, | 714 | min_snr_gamma, |
680 | ) | 715 | ) |
681 | 716 | ||