-rw-r--r--  models/clip/embeddings.py  |  11
-rw-r--r--  models/clip/tokenizer.py   |  74
-rw-r--r--  train_dreambooth.py        | 228
-rw-r--r--  train_ti.py                |  51
-rw-r--r--  training/lr.py             |   6
-rw-r--r--  training/optimization.py   |   2
6 files changed, 226 insertions, 146 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..8602142 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -120,3 +120,14 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
     text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
+
+
+def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings:
+    text_encoder.text_model.embeddings.make_permanent()
+
+    text_embeddings = CLIPTextEmbeddings(text_encoder.config)
+    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
+    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
+    text_encoder.text_model.embeddings = text_embeddings
+
+    return text_embeddings
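The new unpatch_managed_embeddings reverses patch_managed_embeddings: make_permanent() bakes any temporary token vectors into the regular embedding table, after which a stock CLIPTextEmbeddings module is swapped back in, reusing the trained token and position embedding weights. A minimal round-trip sketch (ManagedCLIPTextEmbeddings' internals live elsewhere in this file and are assumed here):

from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

embeddings = patch_managed_embeddings(text_encoder)  # swap in the managed module
# ... train temporary token vectors through `embeddings` ...
unpatch_managed_embeddings(text_encoder)             # bake them in, restore plain CLIPTextEmbeddings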
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 37d69a9..ed9774e 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -1,11 +1,54 @@
 import copy
-from typing import NamedTuple, Union
+from typing import NamedTuple, Union, Literal
 
 import numpy as np
 
 from transformers import CLIPTokenizer
 
 
+def shuffle_all(tokens: list[int]):
+    if len(tokens) >= 2:
+        tokens = copy.copy(tokens)
+        np.random.shuffle(tokens)
+    return tokens
+
+
+def shuffle_leading(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[:-1]
+        np.random.shuffle(subtokens)
+        tokens = subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_trailing(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[1:]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens
+    return tokens
+
+
+def shuffle_between(tokens: list[int]):
+    if len(tokens) >= 4:
+        subtokens = tokens[1:-1]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_none(tokens: list[int]):
+    return tokens
+
+
+def shuffle_auto(tokens: list[int]):
+    if len(tokens) >= 4:
+        return shuffle_between(tokens)
+    if len(tokens) >= 3:
+        return shuffle_trailing(tokens)
+    return shuffle_all(tokens)
+
+
 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
     ids: list[int]
@@ -15,10 +58,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.token_map: dict[int, list[int]] = {}
-        self.vector_shuffle = False
+        self.vector_shuffle = shuffle_none
 
-    def set_use_vector_shuffle(self, enable: bool):
-        self.vector_shuffle = enable
+    def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "auto", "off"]]):
+        if algorithm == "leading":
+            self.vector_shuffle = shuffle_leading
+        elif algorithm == "trailing":
+            self.vector_shuffle = shuffle_trailing
+        elif algorithm == "between":
+            self.vector_shuffle = shuffle_between
+        elif algorithm == "auto":
+            self.vector_shuffle = shuffle_auto
+        elif algorithm == True or algorithm == "all":
+            self.vector_shuffle = shuffle_all
+        else:
+            self.vector_shuffle = shuffle_none
 
     def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
         if isinstance(new_tokens, list):
@@ -47,17 +101,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
     def expand_id(self, id: int):
-        if id in self.token_map:
-            tokens = self.token_map[id]
-
-            if self.vector_shuffle and len(tokens) > 2:
-                subtokens = tokens[1:-1]
-                np.random.shuffle(subtokens)
-                tokens = tokens[:1] + subtokens + tokens[-1:]
-
-            return tokens
-        else:
-            return [id]
+        return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id]
 
     def expand_ids(self, ids: list[int]):
         return [
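The shuffle helpers are pure functions over id lists, so they are easy to check in isolation. A small demo (assuming the module is importable as models.clip.tokenizer); note that the slicing in shuffle_between and the copy.copy in shuffle_all mean the input list is never mutated:

import numpy as np

from models.clip.tokenizer import shuffle_all, shuffle_auto, shuffle_between

np.random.seed(0)  # for a reproducible demo

ids = [101, 102, 103, 104, 105]
print(shuffle_between(ids))  # endpoints stay fixed, e.g. [101, 103, 104, 102, 105]
print(shuffle_auto(ids))     # picks shuffle_between for len >= 4
print(shuffle_auto([1, 2]))  # falls back to shuffle_all for very short lists
print(ids)                   # unchanged: [101, 102, 103, 104, 105]

With this in place, expand_id simply applies whichever strategy set_use_vector_shuffle installed to the id group stored in token_map.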
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1ebcfe3..b07de31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -3,7 +3,6 @@ import itertools
 import math
 import datetime
 import logging
-import json
 from pathlib import Path
 
 import torch
@@ -15,18 +14,21 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.ti import patch_trainable_embeddings
+from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -106,6 +108,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
@@ -193,13 +201,12 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
-        type=float,
-        default=2e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate_text",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -546,9 +553,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as an additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -558,6 +565,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -576,46 +585,42 @@ def main():
         device=accelerator.device
     )
 
-    # Freeze text_encoder and vae
-    vae.requires_grad_(False)
+    embeddings = patch_managed_embeddings(text_encoder)
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = torch.stack([
-            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
-            for token in args.initializer_token
-        ])
-
-        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        print(f"Added {num_added_tokens} new tokens.")
-
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
-
-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        text_encoder.resize_token_embeddings(len(tokenizer))
-
-        token_embeds = text_encoder.get_input_embeddings().weight.data
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+        initializer_token_ids = [
+            tokenizer.encode(token, add_special_tokens=False)
+            for token in args.initializer_token
+        ]
+
+        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        embeddings.resize(len(tokenizer))
+
+        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+            embeddings.add_embed(new_token.ids, init_ids)
+
+        print(f"Added {len(new_tokens)} new tokens.")
     else:
        placeholder_token_id = []
 
+    vae.requires_grad_(False)
+
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
+
+        unpatch_managed_embeddings(text_encoder)
     else:
         print(f"Training added text embeddings")
 
-        patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
         text_encoder.text_model.encoder.requires_grad_(False)
         text_encoder.text_model.final_layer_norm.requires_grad_(False)
         text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
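Token setup therefore changes in two ways: each placeholder can now own several embedding vectors (args.num_vectors), and each group is initialized from the initializer's full encoding instead of only its first sub-token. A hypothetical invocation, continuing from the tokenizer/embeddings setup above, with made-up token names and vector counts:

# Hypothetical: "<subject>" gets 4 embedding vectors, "<style>" gets 2.
new_tokens = tokenizer.add_multi_tokens(["<subject>", "<style>"], [4, 2])
embeddings.resize(len(tokenizer))

initializer_token_ids = [
    tokenizer.encode("woman", add_special_tokens=False),
    tokenizer.encode("painting", add_special_tokens=False),
]
for new_token, init_ids in zip(new_tokens, initializer_token_ids):
    embeddings.add_embed(new_token.ids, init_ids)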
@@ -624,15 +629,14 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-        args.learning_rate_text = (
-            args.learning_rate_text * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -647,20 +651,19 @@ def main():
     if args.train_text_encoder:
         text_encoder_params_to_optimize = text_encoder.parameters()
     else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters()
+        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()
 
     # Initialize the optimizer
     optimizer = optimizer_class(
         [
             {
                 'params': unet.parameters(),
-                'lr': args.learning_rate_unet,
             },
             {
                 'params': text_encoder_params_to_optimize,
-                'lr': args.learning_rate_text,
             }
         ],
+        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
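Collapsing the two learning rates into a single top-level lr= works because PyTorch optimizers fill in the optimizer-level default for every param group that does not carry its own 'lr' key, so the UNet and text-encoder groups now move in lockstep and the LR finder below only has one rate to probe. A standalone illustration:

import torch

unet_params = [torch.nn.Parameter(torch.zeros(2, 2))]
text_params = [torch.nn.Parameter(torch.zeros(2, 2))]

optimizer = torch.optim.AdamW(
    [
        {"params": unet_params},
        {"params": text_params},
    ],
    lr=2e-6,  # applied to every group that lacks an explicit 'lr'
)
print([group["lr"] for group in optimizer.param_groups])  # [2e-06, 2e-06]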
@@ -824,6 +827,58 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == target).float().mean()
+
+        return loss, acc, bsz
+
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initialize automatically on the main process.
     if accelerator.is_main_process:
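The prior-preservation branch in loop assumes the dataloader stacks instance images on top of class images within each batch, which is what makes the torch.chunk split valid. A toy check of that layout:

import torch
import torch.nn.functional as F

# Pretend batch: 2 instance images concatenated with 2 class images.
model_pred = torch.randn(4, 4, 8, 8)
target = torch.randn(4, 4, 8, 8)

model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
loss = loss + 1.0 * prior_loss  # 1.0 stands in for args.prior_loss_weight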
@@ -836,6 +891,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
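LRFinder (training/lr.py) sweeps the learning rate upward from min_lr while recording loss and accuracy, which is why --find_lr first sets args.learning_rate to the sweep ceiling of 1e2. Its implementation is not part of this diff; the core of such a range test is roughly the following generic sketch, not the repo's code:

def lr_range_test(optimizer, loader, step_fn, min_lr=1e-4, max_lr=1e2, num_steps=100):
    # Exponentially ramp the LR from min_lr to max_lr, recording the loss at each step.
    gamma = (max_lr / min_lr) ** (1.0 / num_steps)
    lr, history, data = min_lr, [], iter(loader)
    for _ in range(num_steps):
        for group in optimizer.param_groups:
            group["lr"] = lr
        loss = step_fn(next(data))  # forward/backward + optimizer.step()
        history.append((lr, loss))
        lr *= gamma
    return history  # plot loss vs. lr and pick a value on the descending slope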
@@ -893,58 +957,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def loop(batch):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -993,8 +1005,7 @@ def main():
                        "train/acc": avg_acc.avg.item(),
                        "train/cur_loss": loss.item(),
                        "train/cur_acc": acc.item(),
-                        "lr/unet": lr_scheduler.get_last_lr()[0],
-                        "lr/text": lr_scheduler.get_last_lr()[1]
+                        "lr": lr_scheduler.get_last_lr()[0]
                    }
                    if args.use_ema:
                        logs["ema_decay"] = 1 - ema_unet.decay
@@ -1011,12 +1022,21 @@ def main():
             unet.eval()
             text_encoder.eval()
 
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loop(batch)
 
-                    avg_loss_val.update(loss.detach_(), bsz)
-                    avg_acc_val.update(acc.detach_(), bsz)
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -1029,20 +1049,20 @@ def main():
                    }
                    local_progress_bar.set_postfix(**logs)
 
-            accelerator.log({
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-            }, step=global_step)
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if avg_acc_val.avg.item() > max_acc_val:
-                accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                max_acc_val = avg_acc_val.avg.item()
-
             if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    max_acc_val = avg_acc_val.avg.item()
+
                 if (epoch + 1) % args.sample_frequency == 0:
                     checkpointer.save_samples(global_step, args.sample_steps)
 
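On the validation side, the new cur_loss_val/cur_acc_val meters are created fresh for each validation pass, so val/cur_loss and val/cur_acc reflect only the current epoch while the longer-lived avg_loss_val/avg_acc_val keep accumulating. AverageMeter comes from training/util.py and is not shown in this diff; a typical implementation of the pattern (an assumption, not the repo's exact class) is:

class AverageMeter:
    # Tracks a batch-size-weighted running average.
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count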
diff --git a/train_ti.py b/train_ti.py
index 20a3190..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import datetime
 import logging
@@ -156,6 +155,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
+    parser.add_argument(
         "--dataloader_num_workers",
         type=int,
         default=0,
@@ -245,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=2,
+        default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
@@ -502,20 +507,14 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    if args.find_lr:
-        accelerator = Accelerator(
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
-    else:
-        accelerator = Accelerator(
-            log_with=LoggerType.TENSORBOARD,
-            logging_dir=f"{basepath}",
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
     args.seed = args.seed or (torch.random.seed() >> 32)
     set_seed(args.seed)
@@ -534,7 +533,7 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    tokenizer.set_use_vector_shuffle(True)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -585,7 +584,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e3
+        args.learning_rate = 1e2
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -830,15 +829,6 @@ def main():
 
         return loss, acc, bsz
 
-    if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-4)
-
-        plt.savefig(basepath.joinpath("lr.png"))
-        plt.close()
-
-        quit()
-
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initialize automatically on the main process.
     if accelerator.is_main_process:
@@ -852,6 +842,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("textual_inversion", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
diff --git a/training/lr.py b/training/lr.py
index 0c5ce9e..3abd2f2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -102,6 +102,12 @@ class LRFinder():
             losses.append(loss)
             accs.append(acc)
 
+            self.accelerator.log({
+                "loss": loss,
+                "acc": acc,
+                "lr": lr,
+            }, step=epoch)
+
             progress_bar.set_postfix({
                 "loss": loss,
                 "loss/best": best_loss,
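Logging from inside LRFinder.run is also why both training scripts now build a TensorBoard-enabled Accelerator unconditionally and run the finder only after init_trackers: Accelerate dispatches log calls to the trackers registered at initialization, so anything logged earlier is silently dropped. A minimal sketch of the required ordering:

from accelerate import Accelerator
from accelerate.utils import LoggerType

accelerator = Accelerator(log_with=LoggerType.TENSORBOARD, logging_dir="output/demo")
accelerator.init_trackers("lr_finder_demo")  # must precede any accelerator.log(...)

for epoch, lr in enumerate([1e-4, 1e-3, 1e-2]):
    accelerator.log({"lr": lr}, step=epoch)  # now lands in TensorBoard

accelerator.end_training()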
diff --git a/training/optimization.py b/training/optimization.py
index 3340544..a79944f 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -15,7 +15,7 @@ def get_one_cycle_schedule(
     warmup: Literal["cos", "linear"] = "cos",
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
-    annealing_exp: int = 2,
+    annealing_exp: int = 1,
     min_lr: float = 0.04,
     mid_point: float = 0.3,
     last_epoch: int = -1
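Lowering annealing_exp from 2 to 1 softens the tail of the one-cycle schedule: the cosine annealing term is raised to this exponent, so 1 keeps the plain cosine while 2 pushes the LR toward min_lr much sooner. The real curve is built by get_one_cycle_schedule in training/optimization.py; the following is only an illustrative sketch of how such an exponent reshapes the annealing factor:

import math

def annealing_factor(progress, annealing_exp=1, min_lr=0.04):
    # Cosine annealing from 1.0 down to min_lr, sharpened by annealing_exp.
    cos_term = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (1.0 - min_lr) * cos_term ** annealing_exp

# Halfway through annealing: exponent 1 sits at 52% of peak LR, exponent 2 at 28%.
print(annealing_factor(0.5, annealing_exp=1))  # 0.52
print(annealing_factor(0.5, annealing_exp=2))  # 0.28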
