| author | Volpeon <git@volpeon.ink> | 2023-04-11 22:36:05 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-11 22:36:05 +0200 |
| commit | f39286fa5c5840b67dadf8e85f5f5d7ff1414aab (patch) | |
| tree | b56956444af8439404bb8eb9b82508ac5e2194b9 | |
| parent | Store sample images in Tensorboard as well (diff) | |
Experimental convnext discriminator support
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | models/convnext/discriminator.py | 35 |
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 2 |
| -rw-r--r-- | train_ti.py | 14 |
| -rw-r--r-- | training/functional.py | 83 |
4 files changed, 109 insertions, 25 deletions
```diff
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py
new file mode 100644
index 0000000..7dbbe3a
--- /dev/null
+++ b/models/convnext/discriminator.py
@@ -0,0 +1,35 @@
+import torch
+from timm.models import ConvNeXt
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+from torch.nn import functional as F
+
+
+class ConvNeXtDiscriminator():
+    def __init__(self, model: ConvNeXt, input_size: int) -> None:
+        self.net = model
+
+        self.input_size = input_size
+
+        self.img_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1)
+        self.img_std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)
+
+    def get_score(self, img):
+        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
+        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
+
+        img = ((img+1.)/2.).sub(img_mean).div(img_std)
+
+        img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True)
+        pred = self.net(img)
+        return torch.softmax(pred, dim=-1)[:, 1]
+
+    def get_all(self, img):
+        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
+        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
+
+        img = ((img + 1.) / 2.).sub(img_mean).div(img_std)
+
+        img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True)
+        pred = self.net(img)
+        return pred
```
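The new `ConvNeXtDiscriminator` wraps a timm `ConvNeXt` classifier: it maps decoder output from [-1, 1] back to [0, 1], normalizes with the ImageNet mean and standard deviation, resizes to the classifier's input resolution, and returns either the softmax probability of class index 1 (`get_score`) or the raw logits (`get_all`). A minimal usage sketch; the `convnext_tiny` variant, 3-class head, and 384 px input size mirror the values used in `train_ti.py` below, while the random weights and the dummy batch are purely illustrative:

```python
import torch
from timm.models import create_model

from models.convnext.discriminator import ConvNeXtDiscriminator

# Build a small ConvNeXt head. The weights here are random, so the scores
# only become meaningful once trained classifier weights are loaded.
convnext = create_model("convnext_tiny", pretrained=False, num_classes=3)
convnext.eval().requires_grad_(False)

disc = ConvNeXtDiscriminator(convnext, input_size=384)

# Dummy batch of decoded images in [-1, 1], NCHW, 512x512 RGB.
images = torch.rand(4, 3, 512, 512) * 2 - 1

with torch.no_grad():
    scores = disc.get_score(images)  # softmax probability of class 1, shape (4,)
    logits = disc.get_all(images)    # raw logits, shape (4, 3)
```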
```diff
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cfc3208..13ea2ac 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -477,7 +477,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 2. Define call parameters
         batch_size = len(prompt)
         device = self.execution_device
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
         prep_from_image = isinstance(image, PIL.Image.Image)
```
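In the pipeline, the only change is swapping the deprecated attribute shortcut for the config object: recent diffusers releases emit a deprecation warning when config values such as `in_channels` are read directly off the model. A short sketch of the two forms, assuming a standard `UNet2DConditionModel` (the checkpoint name is only an example):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Preferred: read model hyperparameters from the config object.
num_channels_latents = unet.config.in_channels  # 4 for Stable Diffusion v1.x

# Deprecated shortcut that triggers a warning in recent diffusers releases:
# num_channels_latents = unet.in_channels
```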
```diff
diff --git a/train_ti.py b/train_ti.py
index d7878cd..082e9b7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,10 +13,12 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+from timm.models import create_model
 import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -661,6 +663,17 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    convnext = create_model(
+        "convnext_tiny",
+        pretrained=False,
+        num_classes=3,
+        drop_path_rate=0.0,
+    )
+    convnext.to(accelerator.device, dtype=weight_dtype)
+    convnext.requires_grad_(False)
+    convnext.eval()
+    disc = ConvNeXtDiscriminator(convnext, input_size=384)
+
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
@@ -802,6 +815,7 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
+        disc=disc,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
```
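As added here, the ConvNeXt is created with `pretrained=False` and immediately frozen, so its three-way classification head starts from random weights; the commit does not show where trained discriminator weights come from. A hedged sketch of how a separately trained checkpoint could be loaded before wrapping the model (the checkpoint path and state-dict layout are assumptions, not part of this commit):

```python
import torch
from timm.models import create_model

from models.convnext.discriminator import ConvNeXtDiscriminator

convnext = create_model(
    "convnext_tiny",
    pretrained=False,
    num_classes=3,
    drop_path_rate=0.0,
)

# Hypothetical: restore weights from a classifier trained elsewhere to rate
# image quality. Path and checkpoint format are illustrative only.
state_dict = torch.load("checkpoints/convnext_quality.pt", map_location="cpu")
convnext.load_state_dict(state_dict)

convnext.requires_grad_(False)
convnext.eval()

disc = ConvNeXtDiscriminator(convnext, input_size=384)
```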
```diff
diff --git a/training/functional.py b/training/functional.py
index 2f7f837..be39776 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
+from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.util import AverageMeter
 from util.slerp import slerp
 
@@ -160,7 +161,8 @@ def save_samples(
 
     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
-            tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
+            # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
+            pass
 
     image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
     image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
@@ -275,16 +277,38 @@ def compute_snr(timesteps, noise_scheduler):
     return snr
 
 
-def make_solid_image(color: float, shape, vae, dtype, device, generator):
-    img = torch.tensor(
-        [[[[color]]]],
-        dtype=dtype,
-        device=device
-    ).expand(1, *shape)
-    img = img * 2 - 1
-    img = vae.encode(img).latent_dist.sample(generator=generator)
-    img *= vae.config.scaling_factor
-    return img
+def get_original(
+    noise_scheduler,
+    model_output,
+    sample: torch.FloatTensor,
+    timesteps: torch.IntTensor
+):
+    alphas_cumprod = noise_scheduler.alphas_cumprod
+    sqrt_alphas_cumprod = alphas_cumprod**0.5
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+    alpha = sqrt_alphas_cumprod.expand(sample.shape)
+
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+    sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
+
+    if noise_scheduler.config.prediction_type == "epsilon":
+        pred_original_sample = (sample - sigma * model_output) / alpha
+    elif noise_scheduler.config.prediction_type == "sample":
+        pred_original_sample = model_output
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        pred_original_sample = alpha * sample - sigma * model_output
+    else:
+        raise ValueError(
+            f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
+            " `v_prediction` for the DDPMScheduler."
+        )
+    return pred_original_sample
 
 
 def loss_step(
```
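`get_original` inverts the scheduler's forward process: given a model output and the noisy sample `x_t = alpha_t * x_0 + sigma_t * eps`, it recovers the predicted clean latents `x_0` under the `epsilon`, `sample`, or `v_prediction` parametrization. A small self-check, assuming a diffusers `DDPMScheduler` with epsilon prediction and that the repository root is on the import path: when the true noise is passed as the model output, the reconstruction matches the original latents up to floating-point error.

```python
import torch
from diffusers import DDPMScheduler

from training.functional import get_original

scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="epsilon")

latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (2,))

# Forward process: x_t = sqrt(alpha_cumprod) * x_0 + sqrt(1 - alpha_cumprod) * eps
noisy_latents = scheduler.add_noise(latents, noise, timesteps)

# Feeding the true noise back in inverts the formula exactly:
# x_0 = (x_t - sigma * eps) / alpha
recovered = get_original(scheduler, noise, noisy_latents, timesteps)
print(torch.allclose(recovered, latents, atol=1e-4))  # True
```

The remaining hunks, continued below, thread the discriminator through `loss_step` and `train`.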
```diff
@@ -296,6 +320,7 @@ def loss_step(
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
+    disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
     step: int,
     batch: dict[str, Any],
@@ -373,23 +398,31 @@
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if guidance_scale == 0 and prior_loss_weight != 0:
-        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-        target, target_prior = torch.chunk(target, 2, dim=0)
-
-        # Compute instance loss
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-
-        # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
-
-        # Add the prior loss to the instance loss.
-        loss = loss + prior_loss_weight * prior_loss
-    else:
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-
-    loss = loss.mean([1, 2, 3])
+    if disc is None:
+        if guidance_scale == 0 and prior_loss_weight != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+
+        loss = loss.mean([1, 2, 3])
+    else:
+        rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
+        rec_latent /= vae.config.scaling_factor
+        rec_latent = rec_latent.to(dtype=vae.dtype)
+        rec = vae.decode(rec_latent).sample
+        loss = 1 - disc.get_score(rec)
+        del rec_latent, rec
 
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
```
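When a discriminator is passed, the MSE objective is bypassed entirely: the predicted clean latents are rescaled, decoded through the VAE, and the per-sample loss becomes one minus the discriminator's score for the decoded image. A condensed sketch of that branch as a standalone helper (the function name `discriminator_loss` is not part of the commit; the arguments correspond to the locals used inside `loss_step`):

```python
import torch
from diffusers import AutoencoderKL, DDPMScheduler

from models.convnext.discriminator import ConvNeXtDiscriminator
from training.functional import get_original


def discriminator_loss(
    noise_scheduler: DDPMScheduler,
    vae: AutoencoderKL,
    disc: ConvNeXtDiscriminator,
    model_pred: torch.Tensor,
    noisy_latents: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    # Predict the clean latents from the model output at each timestep.
    rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)

    # Undo the scaling applied when the latents were encoded, match the VAE dtype.
    rec_latent = rec_latent / vae.config.scaling_factor
    rec_latent = rec_latent.to(dtype=vae.dtype)

    # Decode to pixel space and let the discriminator rate the result;
    # a score near 1 drives the loss toward 0.
    rec = vae.decode(rec_latent).sample
    return 1 - disc.get_score(rec)
```

Unlike the MSE branch, this loss is already one value per sample (shape `(batch,)`), so no spatial mean is taken before the min-SNR weighting that follows. The last two hunks simply add `disc` to `train`'s signature and forward it to `loss_step`: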
```diff
@@ -645,6 +678,7 @@ def train(
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     offset_noise_strength: float = 0.15,
+    disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
     **kwargs,
 ):
@@ -676,6 +710,7 @@ def train(
         prior_loss_weight,
         seed,
         offset_noise_strength,
+        disc,
         min_snr_gamma,
     )
 
```
