diff options
-rw-r--r-- | models/convnext/discriminator.py | 35 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 2 | ||||
-rw-r--r-- | train_ti.py | 14 | ||||
-rw-r--r-- | training/functional.py | 83 |
4 files changed, 109 insertions, 25 deletions
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py new file mode 100644 index 0000000..7dbbe3a --- /dev/null +++ b/models/convnext/discriminator.py | |||
@@ -0,0 +1,35 @@ | |||
1 | import torch | ||
2 | from timm.models import ConvNeXt | ||
3 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | ||
4 | |||
5 | from torch.nn import functional as F | ||
6 | |||
7 | |||
class ConvNeXtDiscriminator:
    """Run images through a classifier network after ImageNet-style preprocessing.

    The wrapper owns the preprocessing (range shift, mean/std normalisation,
    bicubic resize to a square input) so that callers only hand over an image
    batch and get logits or a single class probability back.
    """

    def __init__(self, model: "ConvNeXt", input_size: int) -> None:
        """Store the network and the preprocessing constants.

        Args:
            model: Callable network; it is invoked as ``model(img)`` on the
                preprocessed batch.
            input_size: Side length (in pixels) the images are resized to
                before being passed to ``model``.
        """
        self.net = model

        self.input_size = input_size

        # Shaped (1, C, 1, 1) so they broadcast over the batch and spatial dims.
        self.img_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1)
        self.img_std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)

    def _preprocess(self, img):
        """Normalise and resize ``img`` exactly as the network input requires."""
        # Match the statistics to the incoming batch so mixed precision and
        # non-default devices work without the caller moving anything.
        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
        img_std = self.img_std.to(device=img.device, dtype=img.dtype)

        # (img + 1) / 2 maps a [-1, 1] range onto [0, 1] before normalisation.
        # NOTE(review): presumably callers pass images in [-1, 1] -- confirm.
        img = ((img + 1.) / 2.).sub(img_mean).div(img_std)

        return F.interpolate(
            img,
            size=(self.input_size, self.input_size),
            mode='bicubic',
            align_corners=True,
        )

    def get_all(self, img):
        """Return the raw network output for ``img`` (no softmax applied)."""
        return self.net(self._preprocess(img))

    def get_score(self, img):
        """Return the softmax probability of class index 1 for each batch item.

        The softmax is taken over the last dimension of the network output.
        NOTE(review): what class index 1 represents depends on how the
        network was trained -- confirm against the training labels.
        """
        return torch.softmax(self.get_all(img), dim=-1)[:, 1]
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index cfc3208..13ea2ac 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -477,7 +477,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
477 | # 2. Define call parameters | 477 | # 2. Define call parameters |
478 | batch_size = len(prompt) | 478 | batch_size = len(prompt) |
479 | device = self.execution_device | 479 | device = self.execution_device |
480 | num_channels_latents = self.unet.in_channels | 480 | num_channels_latents = self.unet.config.in_channels |
481 | do_classifier_free_guidance = guidance_scale > 1.0 | 481 | do_classifier_free_guidance = guidance_scale > 1.0 |
482 | do_self_attention_guidance = sag_scale > 0.0 | 482 | do_self_attention_guidance = sag_scale > 0.0 |
483 | prep_from_image = isinstance(image, PIL.Image.Image) | 483 | prep_from_image = isinstance(image, PIL.Image.Image) |
diff --git a/train_ti.py b/train_ti.py index d7878cd..082e9b7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -13,10 +13,12 @@ from accelerate import Accelerator | |||
13 | from accelerate.logging import get_logger | 13 | from accelerate.logging import get_logger |
14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
15 | from slugify import slugify | 15 | from slugify import slugify |
16 | from timm.models import create_model | ||
16 | import transformers | 17 | import transformers |
17 | 18 | ||
18 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
19 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
21 | from models.convnext.discriminator import ConvNeXtDiscriminator | ||
20 | from training.functional import train, add_placeholder_tokens, get_models | 22 | from training.functional import train, add_placeholder_tokens, get_models |
21 | from training.strategy.ti import textual_inversion_strategy | 23 | from training.strategy.ti import textual_inversion_strategy |
22 | from training.optimization import get_scheduler | 24 | from training.optimization import get_scheduler |
@@ -661,6 +663,17 @@ def main(): | |||
661 | unet.enable_gradient_checkpointing() | 663 | unet.enable_gradient_checkpointing() |
662 | text_encoder.gradient_checkpointing_enable() | 664 | text_encoder.gradient_checkpointing_enable() |
663 | 665 | ||
666 | convnext = create_model( | ||
667 | "convnext_tiny", | ||
668 | pretrained=False, | ||
669 | num_classes=3, | ||
670 | drop_path_rate=0.0, | ||
671 | ) | ||
672 | convnext.to(accelerator.device, dtype=weight_dtype) | ||
673 | convnext.requires_grad_(False) | ||
674 | convnext.eval() | ||
675 | disc = ConvNeXtDiscriminator(convnext, input_size=384) | ||
676 | |||
664 | if len(args.alias_tokens) != 0: | 677 | if len(args.alias_tokens) != 0: |
665 | alias_placeholder_tokens = args.alias_tokens[::2] | 678 | alias_placeholder_tokens = args.alias_tokens[::2] |
666 | alias_initializer_tokens = args.alias_tokens[1::2] | 679 | alias_initializer_tokens = args.alias_tokens[1::2] |
@@ -802,6 +815,7 @@ def main(): | |||
802 | milestone_checkpoints=not args.no_milestone_checkpoints, | 815 | milestone_checkpoints=not args.no_milestone_checkpoints, |
803 | global_step_offset=global_step_offset, | 816 | global_step_offset=global_step_offset, |
804 | offset_noise_strength=args.offset_noise_strength, | 817 | offset_noise_strength=args.offset_noise_strength, |
818 | disc=disc, | ||
805 | # -- | 819 | # -- |
806 | use_emb_decay=args.use_emb_decay, | 820 | use_emb_decay=args.use_emb_decay, |
807 | emb_decay_target=args.emb_decay_target, | 821 | emb_decay_target=args.emb_decay_target, |
diff --git a/training/functional.py b/training/functional.py index 2f7f837..be39776 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -23,6 +23,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
23 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 23 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
24 | from models.clip.util import get_extended_embeddings | 24 | from models.clip.util import get_extended_embeddings |
25 | from models.clip.tokenizer import MultiCLIPTokenizer | 25 | from models.clip.tokenizer import MultiCLIPTokenizer |
26 | from models.convnext.discriminator import ConvNeXtDiscriminator | ||
26 | from training.util import AverageMeter | 27 | from training.util import AverageMeter |
27 | from util.slerp import slerp | 28 | from util.slerp import slerp |
28 | 29 | ||
@@ -160,7 +161,8 @@ def save_samples( | |||
160 | 161 | ||
161 | for tracker in accelerator.trackers: | 162 | for tracker in accelerator.trackers: |
162 | if tracker.name == "tensorboard": | 163 | if tracker.name == "tensorboard": |
163 | tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") | 164 | # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") |
165 | pass | ||
164 | 166 | ||
165 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) | 167 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) |
166 | image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] | 168 | image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] |
@@ -275,16 +277,38 @@ def compute_snr(timesteps, noise_scheduler): | |||
275 | return snr | 277 | return snr |
276 | 278 | ||
277 | 279 | ||
def _gather_coefficient(schedule, timesteps, sample):
    """Select per-timestep values from ``schedule`` and broadcast them to ``sample``."""
    # The coefficients are combined with ``sample``, so they live on its device.
    device = sample.device
    coeff = schedule.to(device=device)[timesteps.to(device=device)].float()

    # Append trailing singleton dims until the rank matches ``sample``.
    missing_dims = sample.ndim - coeff.ndim
    if missing_dims > 0:
        coeff = coeff.reshape(*coeff.shape, *([1] * missing_dims))

    return coeff.expand(sample.shape)


def get_original(
    noise_scheduler,
    model_output,
    sample: torch.FloatTensor,
    timesteps: torch.IntTensor
):
    """Reconstruct the predicted original (denoised) sample from a model output.

    Args:
        noise_scheduler: Object exposing ``alphas_cumprod`` (indexable by
            timestep) and ``config.prediction_type``.
        model_output: The model prediction; how it is interpreted depends on
            ``config.prediction_type``.
        sample: The noisy sample the prediction was made for.
        timesteps: Indices into ``alphas_cumprod``, one per batch item.

    Returns:
        The predicted original sample. For ``"sample"`` this is
        ``model_output`` itself; otherwise it is computed from ``sample`` and
        ``model_output`` with float32 schedule coefficients.

    Raises:
        ValueError: If ``config.prediction_type`` is not one of ``epsilon``,
            ``sample`` or ``v_prediction``.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    alpha = _gather_coefficient(alphas_cumprod ** 0.5, timesteps, sample)
    sigma = _gather_coefficient((1.0 - alphas_cumprod) ** 0.5, timesteps, sample)

    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "epsilon":
        # sample = alpha * x0 + sigma * eps  =>  x0 = (sample - sigma * eps) / alpha
        return (sample - sigma * model_output) / alpha
    if prediction_type == "sample":
        return model_output
    if prediction_type == "v_prediction":
        # v = alpha * eps - sigma * x0  =>  x0 = alpha * sample - sigma * v
        return alpha * sample - sigma * model_output

    raise ValueError(
        f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample` or"
        " `v_prediction` for the DDPMScheduler."
    )
288 | 312 | ||
289 | 313 | ||
290 | def loss_step( | 314 | def loss_step( |
@@ -296,6 +320,7 @@ def loss_step( | |||
296 | prior_loss_weight: float, | 320 | prior_loss_weight: float, |
297 | seed: int, | 321 | seed: int, |
298 | offset_noise_strength: float, | 322 | offset_noise_strength: float, |
323 | disc: Optional[ConvNeXtDiscriminator], | ||
299 | min_snr_gamma: int, | 324 | min_snr_gamma: int, |
300 | step: int, | 325 | step: int, |
301 | batch: dict[str, Any], | 326 | batch: dict[str, Any], |
@@ -373,23 +398,31 @@ def loss_step( | |||
373 | else: | 398 | else: |
374 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 399 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
375 | 400 | ||
376 | if guidance_scale == 0 and prior_loss_weight != 0: | 401 | if disc is None: |
377 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 402 | if guidance_scale == 0 and prior_loss_weight != 0: |
378 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 403 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
379 | target, target_prior = torch.chunk(target, 2, dim=0) | 404 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
405 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
380 | 406 | ||
381 | # Compute instance loss | 407 | # Compute instance loss |
382 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 408 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
383 | 409 | ||
384 | # Compute prior loss | 410 | # Compute prior loss |
385 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") | 411 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") |
386 | 412 | ||
387 | # Add the prior loss to the instance loss. | 413 | # Add the prior loss to the instance loss. |
388 | loss = loss + prior_loss_weight * prior_loss | 414 | loss = loss + prior_loss_weight * prior_loss |
389 | else: | 415 | else: |
390 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 416 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
391 | 417 | ||
392 | loss = loss.mean([1, 2, 3]) | 418 | loss = loss.mean([1, 2, 3]) |
419 | else: | ||
420 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) | ||
421 | rec_latent /= vae.config.scaling_factor | ||
422 | rec_latent = rec_latent.to(dtype=vae.dtype) | ||
423 | rec = vae.decode(rec_latent).sample | ||
424 | loss = 1 - disc.get_score(rec) | ||
425 | del rec_latent, rec | ||
393 | 426 | ||
394 | if min_snr_gamma != 0: | 427 | if min_snr_gamma != 0: |
395 | snr = compute_snr(timesteps, noise_scheduler) | 428 | snr = compute_snr(timesteps, noise_scheduler) |
@@ -645,6 +678,7 @@ def train( | |||
645 | guidance_scale: float = 0.0, | 678 | guidance_scale: float = 0.0, |
646 | prior_loss_weight: float = 1.0, | 679 | prior_loss_weight: float = 1.0, |
647 | offset_noise_strength: float = 0.15, | 680 | offset_noise_strength: float = 0.15, |
681 | disc: Optional[ConvNeXtDiscriminator] = None, | ||
648 | min_snr_gamma: int = 5, | 682 | min_snr_gamma: int = 5, |
649 | **kwargs, | 683 | **kwargs, |
650 | ): | 684 | ): |
@@ -676,6 +710,7 @@ def train( | |||
676 | prior_loss_weight, | 710 | prior_loss_weight, |
677 | seed, | 711 | seed, |
678 | offset_noise_strength, | 712 | offset_noise_strength, |
713 | disc, | ||
679 | min_snr_gamma, | 714 | min_snr_gamma, |
680 | ) | 715 | ) |
681 | 716 | ||