| author    | Volpeon <git@volpeon.ink> | 2023-01-01 19:19:52 +0100 |
|-----------|---------------------------|---------------------------|
| committer | Volpeon <git@volpeon.ink> | 2023-01-01 19:19:52 +0100 |
| commit    | adc52fb8821a496bc8d78235bf10466b39df03e0 (patch) | |
| tree      | 8a6337a6ac10cbe76c55514ab559c647e69fb1aa | |
| parent    | Fixed accuracy calc, other improvements (diff) | |
| download  | textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.gz, textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.bz2, textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.zip | |
Updates

-rw-r--r--  models/clip/embeddings.py   11
-rw-r--r--  models/clip/tokenizer.py    76
-rw-r--r--  train_dreambooth.py        228
-rw-r--r--  train_ti.py                 51
-rw-r--r--  training/lr.py               6
-rw-r--r--  training/optimization.py     2

6 files changed, 227 insertions(+), 147 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..8602142 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -120,3 +120,14 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe
     text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
+
+
+def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings:
+    text_encoder.text_model.embeddings.make_permanent()
+
+    text_embeddings = CLIPTextEmbeddings(text_encoder.config)
+    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
+    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
+    text_encoder.text_model.embeddings = text_embeddings
+
+    return text_embeddings
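The new `unpatch_managed_embeddings` is the inverse of `patch_managed_embeddings`: `make_permanent()` bakes any temporarily trained token vectors into the embedding table, then a stock `CLIPTextEmbeddings` module is swapped back in, reusing the existing token and position embeddings. A sketch of the intended round trip, assuming only a standard `CLIPTextModel` checkpoint (the checkpoint name is illustrative):

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Swap in managed embeddings so newly added tokens can be trained in isolation.
    embeddings = patch_managed_embeddings(text_encoder)

    # ... add tokens and train their temporary vectors ...

    # Bake the trained vectors into the weight matrix and restore the stock
    # CLIPTextEmbeddings module, e.g. before fine-tuning the whole text encoder.
    unpatch_managed_embeddings(text_encoder)

This is exactly how `train_dreambooth.py` below uses the pair: patch first, then unpatch when the entire text encoder is to be trained.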
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 37d69a9..ed9774e 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -1,11 +1,54 @@
 import copy
-from typing import NamedTuple, Union
+from typing import NamedTuple, Union, Literal
 
 import numpy as np
 
 from transformers import CLIPTokenizer
 
 
+def shuffle_all(tokens: list[int]):
+    if len(tokens) >= 2:
+        tokens = copy.copy(tokens)
+        np.random.shuffle(tokens)
+    return tokens
+
+
+def shuffle_leading(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[:-1]
+        np.random.shuffle(subtokens)
+        tokens = subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_trailing(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[1:]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens
+    return tokens
+
+
+def shuffle_between(tokens: list[int]):
+    if len(tokens) >= 4:
+        subtokens = tokens[1:-1]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_none(tokens: list[int]):
+    return tokens
+
+
+def shuffle_auto(tokens: list[int]):
+    if len(tokens) >= 4:
+        return shuffle_between(tokens)
+    if len(tokens) >= 3:
+        return shuffle_trailing(tokens)
+    return shuffle_all(tokens)
+
+
 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
     ids: list[int]
@@ -15,10 +58,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.token_map: dict[int, list[int]] = {}
-        self.vector_shuffle = False
+        self.vector_shuffle = shuffle_none
 
-    def set_use_vector_shuffle(self, enable: bool):
-        self.vector_shuffle = enable
+    def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "auto", "off"]]):
+        if algorithm == "leading":
+            self.vector_shuffle = shuffle_leading
+        elif algorithm == "trailing":
+            self.vector_shuffle = shuffle_trailing
+        elif algorithm == "between":
+            self.vector_shuffle = shuffle_between
+        elif algorithm == "auto":
+            self.vector_shuffle = shuffle_auto
+        elif algorithm is True or algorithm == "all":
+            self.vector_shuffle = shuffle_all
+        else:
+            self.vector_shuffle = shuffle_none
 
     def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
         if isinstance(new_tokens, list):
@@ -47,17 +101,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
     def expand_id(self, id: int):
-        if id in self.token_map:
-            tokens = self.token_map[id]
-
-            if self.vector_shuffle and len(tokens) > 2:
-                subtokens = tokens[1:-1]
-                np.random.shuffle(subtokens)
-                tokens = tokens[:1] + subtokens + tokens[-1:]
-
-            return tokens
-        else:
-            return [id]
+        return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id]
 
     def expand_ids(self, ids: list[int]):
         return [
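Shuffling is now pluggable: `set_use_vector_shuffle` maps the CLI-facing name (or a legacy boolean) to one of the module-level shuffle functions, and `expand_id` simply applies whichever function is active to a token's vector ids. Each variant pins different positions of a multi-vector token: `between` keeps the first and last vector in place, `trailing` pins only the first, `leading` pins only the last, `all` shuffles everything, and `auto` picks the strongest variant the vector count allows. A quick demonstration of the functions themselves, with hypothetical token ids:

    import numpy as np

    from models.clip.tokenizer import shuffle_all, shuffle_between, shuffle_trailing

    np.random.seed(0)  # for a reproducible demonstration

    ids = [101, 102, 103, 104]   # hypothetical ids of one 4-vector token

    print(shuffle_all(ids))      # any permutation of all four ids
    print(shuffle_trailing(ids)) # 101 stays first; 102-104 are shuffled
    print(shuffle_between(ids))  # 101 and 104 stay pinned; 102/103 may swap
    print(ids)                   # unchanged: every variant copies or slices first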
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1ebcfe3..b07de31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -3,7 +3,6 @@ import itertools
 import math
 import datetime
 import logging
-import json
 from pathlib import Path
 
 import torch
@@ -15,18 +14,21 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.ti import patch_trainable_embeddings
+from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -106,6 +108,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
@@ -193,13 +201,12 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
-        type=float,
-        default=2e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate_text",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -546,9 +553,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -558,6 +565,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -576,46 +585,42 @@ def main():
         device=accelerator.device
     )
 
-    # Freeze text_encoder and vae
-    vae.requires_grad_(False)
+    embeddings = patch_managed_embeddings(text_encoder)
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = torch.stack([
-            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        initializer_token_ids = [
+            tokenizer.encode(token, add_special_tokens=False)
             for token in args.initializer_token
-        ])
-
-        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        print(f"Added {num_added_tokens} new tokens.")
+        ]
 
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        embeddings.resize(len(tokenizer))
 
-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        text_encoder.resize_token_embeddings(len(tokenizer))
+        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+            embeddings.add_embed(new_token.ids, init_ids)
 
-        token_embeds = text_encoder.get_input_embeddings().weight.data
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+        print(f"Added {len(new_tokens)} new tokens.")
     else:
         placeholder_token_id = []
 
+    vae.requires_grad_(False)
+
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
+
+        unpatch_managed_embeddings(text_encoder)
     else:
         print(f"Training added text embeddings")
 
-        patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
         text_encoder.text_model.encoder.requires_grad_(False)
         text_encoder.text_model.final_layer_norm.requires_grad_(False)
         text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
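Placeholder setup in `train_dreambooth.py` now mirrors `train_ti.py`: a placeholder maps to `num_vectors` sub-token ids instead of a single id, the managed embedding table is resized once, and each new token is seeded from its initializer's embeddings via `add_embed`. A sketch of the sequence in isolation, assuming the repo's modules behave as the diff suggests (checkpoint name, placeholder, and vector count are illustrative):

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings
    from models.clip.tokenizer import MultiCLIPTokenizer

    model = "runwayml/stable-diffusion-v1-5"
    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder")
    embeddings = patch_managed_embeddings(text_encoder)

    # "<token>" becomes 4 consecutive embedding vectors behind one tokenizer entry.
    init_ids = tokenizer.encode("person", add_special_tokens=False)
    new_tokens = tokenizer.add_multi_tokens(["<token>"], [4])
    embeddings.resize(len(tokenizer))

    for new_token in new_tokens:
        embeddings.add_embed(new_token.ids, init_ids)  # seed from the initializer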
@@ -624,15 +629,14 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-        args.learning_rate_text = (
-            args.learning_rate_text * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -647,20 +651,19 @@ def main():
     if args.train_text_encoder:
         text_encoder_params_to_optimize = text_encoder.parameters()
     else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters()
+        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()
 
     # Initialize the optimizer
     optimizer = optimizer_class(
         [
             {
                 'params': unet.parameters(),
-                'lr': args.learning_rate_unet,
             },
             {
                 'params': text_encoder_params_to_optimize,
-                'lr': args.learning_rate_text,
             }
         ],
+        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
@@ -824,6 +827,58 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == target).float().mean()
+
+        return loss, acc, bsz
+
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
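Hoisting `loop` above the tracker setup lets the new `--find_lr` path hand the same closure to `LRFinder` before any training happens. Its prior-preservation branch assumes each batch concatenates instance and class samples along the batch dimension, which is why a single `torch.chunk` splits prediction and target back apart. The loss arithmetic in isolation, with random tensors standing in for the model outputs:

    import torch
    import torch.nn.functional as F

    prior_loss_weight = 1.0  # illustrative value of args.prior_loss_weight

    # Batch of 4 = 2 instance samples followed by 2 class (prior) samples.
    model_pred = torch.randn(4, 4, 64, 64)
    target = torch.randn(4, 4, 64, 64)

    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

    # Instance loss plus weighted prior-preservation loss.
    loss = loss + prior_loss_weight * prior_loss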
@@ -836,6 +891,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
@@ -893,58 +957,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def loop(batch):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -993,8 +1005,7 @@ def main():
                     "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
-                    "lr/unet": lr_scheduler.get_last_lr()[0],
-                    "lr/text": lr_scheduler.get_last_lr()[1]
+                    "lr": lr_scheduler.get_last_lr()[0]
                 }
                 if args.use_ema:
                     logs["ema_decay"] = 1 - ema_unet.decay
@@ -1011,12 +1022,21 @@ def main():
             unet.eval()
             text_encoder.eval()
 
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loop(batch)
 
-                    avg_loss_val.update(loss.detach_(), bsz)
-                    avg_acc_val.update(acc.detach_(), bsz)
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -1029,20 +1049,20 @@ def main():
                     }
                     local_progress_bar.set_postfix(**logs)
 
-            accelerator.log({
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-            }, step=global_step)
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if avg_acc_val.avg.item() > max_acc_val:
-                accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                max_acc_val = avg_acc_val.avg.item()
-
             if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    max_acc_val = avg_acc_val.avg.item()
+
                 if (epoch + 1) % args.sample_frequency == 0:
                     checkpointer.save_samples(global_step, args.sample_steps)
 
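The validation pass now tracks two granularities: `cur_loss_val`/`cur_acc_val` are fresh meters each epoch and feed the new `val/cur_loss` and `val/cur_acc` metrics, while `avg_loss_val`/`avg_acc_val` keep accumulating across the run and still drive the best-accuracy message, now guarded by `accelerator.is_main_process`. A minimal, illustrative stand-in for such a batch-size-weighted meter; the real `AverageMeter` is imported from `training.util` (and works on tensors), so its details may differ:

    class AverageMeter:
        """Batch-size-weighted running average (illustrative stand-in)."""

        def __init__(self):
            self.sum = 0.0
            self.count = 0
            self.avg = 0.0

        def update(self, value, n=1):
            self.sum += value * n
            self.count += n
            self.avg = self.sum / self.count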
diff --git a/train_ti.py b/train_ti.py
index 20a3190..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import datetime
 import logging
@@ -156,6 +155,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
+    parser.add_argument(
         "--dataloader_num_workers",
         type=int,
         default=0,
@@ -245,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=2,
+        default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
@@ -502,20 +507,14 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    if args.find_lr:
-        accelerator = Accelerator(
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
-    else:
-        accelerator = Accelerator(
-            log_with=LoggerType.TENSORBOARD,
-            logging_dir=f"{basepath}",
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
     args.seed = args.seed or (torch.random.seed() >> 32)
     set_seed(args.seed)
@@ -534,7 +533,7 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    tokenizer.set_use_vector_shuffle(True)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -585,7 +584,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e3
+        args.learning_rate = 1e2
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -830,15 +829,6 @@ def main():
 
         return loss, acc, bsz
 
-    if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-4)
-
-        plt.savefig(basepath.joinpath("lr.png"))
-        plt.close()
-
-        quit()
-
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
@@ -852,6 +842,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("textual_inversion", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
diff --git a/training/lr.py b/training/lr.py
index 0c5ce9e..3abd2f2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -102,6 +102,12 @@ class LRFinder():
             losses.append(loss)
             accs.append(acc)
 
+            self.accelerator.log({
+                "loss": loss,
+                "acc": acc,
+                "lr": lr,
+            }, step=epoch)
+
             progress_bar.set_postfix({
                 "loss": loss,
                 "loss/best": best_loss,
diff --git a/training/optimization.py b/training/optimization.py
index 3340544..a79944f 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -15,7 +15,7 @@ def get_one_cycle_schedule(
     warmup: Literal["cos", "linear"] = "cos",
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
-    annealing_exp: int = 2,
+    annealing_exp: int = 1,
     min_lr: int = 0.04,
     mid_point: int = 0.3,
     last_epoch: int = -1
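Lowering the default `annealing_exp` from 2 to 1 softens the one-cycle decay: the cosine term is no longer squared, so the learning rate stays higher for longer after the `mid_point` (this matches the new `--lr_annealing_exp` default in `train_ti.py`). A sketch of the annealing shape under the assumption that the exponent simply powers a standard cosine ramp down to `min_lr`; the authoritative curve is `get_one_cycle_schedule` itself:

    import math

    def cos_anneal(progress, annealing_exp=1, min_lr=0.04):
        """LR multiplier for the annealing phase; progress runs from 0 to 1."""
        lr = 0.5 * (1.0 + math.cos(math.pi * progress)) ** annealing_exp
        return min_lr + (1.0 - min_lr) * lr

    # Halfway through annealing, exp=1 keeps a 0.52 multiplier of the peak LR,
    # while the old exp=2 default would already be down to 0.28.
    print(round(cos_anneal(0.5, annealing_exp=1), 2))  # 0.52
    print(round(cos_anneal(0.5, annealing_exp=2), 2))  # 0.28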