author    Volpeon <git@volpeon.ink>  2023-04-11 22:36:05 +0200
committer Volpeon <git@volpeon.ink>  2023-04-11 22:36:05 +0200
commit    f39286fa5c5840b67dadf8e85f5f5d7ff1414aab (patch)
tree      b56956444af8439404bb8eb9b82508ac5e2194b9
parent    Store sample images in Tensorboard as well (diff)
Experimental convnext discriminator support
-rw-r--r--  models/convnext/discriminator.py                      35
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py    2
-rw-r--r--  train_ti.py                                            14
-rw-r--r--  training/functional.py                                 83
4 files changed, 109 insertions(+), 25 deletions(-)
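In short: when a discriminator is passed down to the training loop, the usual per-pixel MSE objective is skipped entirely. Instead, the predicted original latents are recovered from the model output, decoded through the VAE, and scored by a frozen ConvNeXt classifier; the loss is one minus that score. A rough sketch of the objective, using the names introduced in the diffs below:

    # Sketch only; get_original, disc, and the surrounding tensors are
    # defined in the diffs that follow.
    rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
    rec = vae.decode((rec_latent / vae.config.scaling_factor).to(dtype=vae.dtype)).sample
    loss = 1 - disc.get_score(rec)  # per-sample scores in [0, 1]; high score = judged "real"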
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py
new file mode 100644
index 0000000..7dbbe3a
--- /dev/null
+++ b/models/convnext/discriminator.py
@@ -0,0 +1,35 @@
+import torch
+from timm.models import ConvNeXt
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+from torch.nn import functional as F
+
+
+class ConvNeXtDiscriminator():
+    def __init__(self, model: ConvNeXt, input_size: int) -> None:
+        self.net = model
+
+        self.input_size = input_size
+
+        self.img_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1)
+        self.img_std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)
+
+    def get_score(self, img):
+        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
+        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
+
+        img = ((img + 1.) / 2.).sub(img_mean).div(img_std)
+
+        img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True)
+        pred = self.net(img)
+        return torch.softmax(pred, dim=-1)[:, 1]
+
+    def get_all(self, img):
+        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
+        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
+
+        img = ((img + 1.) / 2.).sub(img_mean).div(img_std)
+
+        img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True)
+        pred = self.net(img)
+        return pred
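A minimal usage sketch for the new class (hypothetical, not part of the commit); it assumes a timm ConvNeXt whose class index 1 means "real", which is what get_score() reads out:

    import torch
    from timm.models import create_model
    from models.convnext.discriminator import ConvNeXtDiscriminator

    convnext = create_model("convnext_tiny", pretrained=False, num_classes=3)
    convnext.eval()
    disc = ConvNeXtDiscriminator(convnext, input_size=384)

    # Inputs are expected in [-1, 1], the range the VAE decoder produces.
    imgs = torch.rand(4, 3, 512, 512) * 2 - 1
    scores = disc.get_score(imgs)  # shape (4,): softmax probability of class 1

During training the scores are used with autograd enabled, so the sketch deliberately avoids torch.no_grad(); the network is frozen via requires_grad_(False) instead (see train_ti.py below).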
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cfc3208..13ea2ac 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -477,7 +477,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 2. Define call parameters
         batch_size = len(prompt)
         device = self.execution_device
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
         prep_from_image = isinstance(image, PIL.Image.Image)
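The one-line pipeline change tracks a diffusers deprecation: model hyperparameters such as in_channels are meant to be read from the model's config object rather than as direct attributes. Illustrative snippet (assuming an already-loaded UNet2DConditionModel):

    num_channels_latents = unet.config.in_channels  # preferred
    # unet.in_channels still works in recent diffusers releases but emits a deprecation warning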
diff --git a/train_ti.py b/train_ti.py
index d7878cd..082e9b7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,10 +13,12 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+from timm.models import create_model
 import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -661,6 +663,17 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    convnext = create_model(
+        "convnext_tiny",
+        pretrained=False,
+        num_classes=3,
+        drop_path_rate=0.0,
+    )
+    convnext.to(accelerator.device, dtype=weight_dtype)
+    convnext.requires_grad_(False)
+    convnext.eval()
+    disc = ConvNeXtDiscriminator(convnext, input_size=384)
+
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
@@ -802,6 +815,7 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
+        disc=disc,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
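One caveat worth flagging: pretrained=False leaves the ConvNeXt randomly initialized, and this diff loads no checkpoint, which presumably is part of what makes the feature experimental. For the scores to mean anything, trained discriminator weights would have to be loaded before training starts; a hypothetical sketch (path and checkpoint format are assumptions, not from the commit):

    # Hypothetical: load a separately trained discriminator checkpoint.
    state_dict = torch.load("checkpoints/convnext_disc.pt", map_location="cpu")
    convnext.load_state_dict(state_dict)
    convnext.to(accelerator.device, dtype=weight_dtype)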
diff --git a/training/functional.py b/training/functional.py
index 2f7f837..be39776 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
+from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.util import AverageMeter
 from util.slerp import slerp
 
@@ -160,7 +161,8 @@ def save_samples(
 
     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
-            tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
+            # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
+            pass
 
     image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
     image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
@@ -275,16 +277,38 @@ def compute_snr(timesteps, noise_scheduler):
     return snr
 
 
-def make_solid_image(color: float, shape, vae, dtype, device, generator):
-    img = torch.tensor(
-        [[[[color]]]],
-        dtype=dtype,
-        device=device
-    ).expand(1, *shape)
-    img = img * 2 - 1
-    img = vae.encode(img).latent_dist.sample(generator=generator)
-    img *= vae.config.scaling_factor
-    return img
+def get_original(
+    noise_scheduler,
+    model_output,
+    sample: torch.FloatTensor,
+    timesteps: torch.IntTensor
+):
+    alphas_cumprod = noise_scheduler.alphas_cumprod
+    sqrt_alphas_cumprod = alphas_cumprod**0.5
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+    alpha = sqrt_alphas_cumprod.expand(sample.shape)
+
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+    sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
+
+    if noise_scheduler.config.prediction_type == "epsilon":
+        pred_original_sample = (sample - sigma * model_output) / alpha
+    elif noise_scheduler.config.prediction_type == "sample":
+        pred_original_sample = model_output
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        pred_original_sample = alpha * sample - sigma * model_output
+    else:
+        raise ValueError(
+            f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
+            " `v_prediction` for the DDPMScheduler."
+        )
+    return pred_original_sample
 
 
 def loss_step(
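A quick sanity check that get_original() inverts the forward noising process (a sketch under assumed setup, not part of the commit; uses a stock diffusers DDPMScheduler):

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(prediction_type="epsilon")
    latents = torch.randn(2, 4, 64, 64)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (2,))

    noisy = scheduler.add_noise(latents, noise, timesteps)
    # With the true noise passed as model_output, the original sample comes back.
    recovered = get_original(scheduler, noise, noisy, timesteps)
    assert torch.allclose(recovered, latents, atol=1e-3)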
@@ -296,6 +320,7 @@ def loss_step(
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
+    disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
     step: int,
     batch: dict[str, Any],
@@ -373,23 +398,31 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if guidance_scale == 0 and prior_loss_weight != 0:
-        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-        target, target_prior = torch.chunk(target, 2, dim=0)
-
-        # Compute instance loss
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-
-        # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
-
-        # Add the prior loss to the instance loss.
-        loss = loss + prior_loss_weight * prior_loss
-    else:
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-
-    loss = loss.mean([1, 2, 3])
+    if disc is None:
+        if guidance_scale == 0 and prior_loss_weight != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+
+        loss = loss.mean([1, 2, 3])
+    else:
+        rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
+        rec_latent /= vae.config.scaling_factor
+        rec_latent = rec_latent.to(dtype=vae.dtype)
+        rec = vae.decode(rec_latent).sample
+        loss = 1 - disc.get_score(rec)
+        del rec_latent, rec
 
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
@@ -645,6 +678,7 @@ def train(
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     offset_noise_strength: float = 0.15,
+    disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
     **kwargs,
 ):
@@ -676,6 +710,7 @@ def train(
         prior_loss_weight,
         seed,
         offset_noise_strength,
+        disc,
         min_snr_gamma,
     )
 
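Two details worth noting about the new loss path. First, 1 - disc.get_score(rec) already yields one value per sample, matching the shape of loss.mean([1, 2, 3]) in the MSE branch, so the min-SNR weighting that follows applies in either case. Second, the discriminator is frozen with requires_grad_(False) and eval() rather than being called under torch.no_grad(), because gradients still have to flow through it (and through vae.decode) back into the trainable embeddings. A minimal stand-in illustration of that pattern (hypothetical module, for exposition only):

    import torch
    from torch import nn

    frozen = nn.Linear(4, 1)
    frozen.requires_grad_(False)

    x = torch.randn(2, 4, requires_grad=True)
    loss = (1 - torch.sigmoid(frozen(x))).mean()
    loss.backward()

    print(x.grad is not None)          # True: gradients flow through the frozen net
    print(frozen.weight.grad is None)  # True: the frozen net itself accumulates none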