author     Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
commit     adc52fb8821a496bc8d78235bf10466b39df03e0
tree       8a6337a6ac10cbe76c55514ab559c647e69fb1aa
parent     Fixed accuracy calc, other improvements
Updates
-rw-r--r--  models/clip/embeddings.py  |  11
-rw-r--r--  models/clip/tokenizer.py   |  76
-rw-r--r--  train_dreambooth.py        | 228
-rw-r--r--  train_ti.py                |  51
-rw-r--r--  training/lr.py             |   6
-rw-r--r--  training/optimization.py   |   2
6 files changed, 227 insertions(+), 147 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..8602142 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -120,3 +120,14 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe
     text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
+
+
+def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings:
+    text_encoder.text_model.embeddings.make_permanent()
+
+    text_embeddings = CLIPTextEmbeddings(text_encoder.config)
+    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
+    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
+    text_encoder.text_model.embeddings = text_embeddings
+
+    return text_embeddings
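
A minimal usage sketch of the new unpatch helper (not part of the commit; the checkpoint id is a placeholder, and it is assumed that make_permanent() folds any temporary token embeddings into the regular embedding table before the swap back to a plain CLIPTextEmbeddings):

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings

    text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

    embeddings = patch_managed_embeddings(text_encoder)   # swap in ManagedCLIPTextEmbeddings
    # ... add placeholder tokens via embeddings.add_embed(...) and train them ...
    unpatch_managed_embeddings(text_encoder)               # restore a plain CLIPTextEmbeddings module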
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 37d69a9..ed9774e 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -1,11 +1,54 @@
 import copy
-from typing import NamedTuple, Union
+from typing import NamedTuple, Union, Literal
 
 import numpy as np
 
 from transformers import CLIPTokenizer
 
 
+def shuffle_all(tokens: list[int]):
+    if len(tokens) >= 2:
+        tokens = copy.copy(tokens)
+        np.random.shuffle(tokens)
+    return tokens
+
+
+def shuffle_leading(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[:-1]
+        np.random.shuffle(subtokens)
+        tokens = subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_trailing(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[1:]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens
+    return tokens
+
+
+def shuffle_between(tokens: list[int]):
+    if len(tokens) >= 4:
+        subtokens = tokens[1:-1]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_none(tokens: list[int]):
+    return tokens
+
+
+def shuffle_auto(tokens: list[int]):
+    if len(tokens) >= 4:
+        return shuffle_between(tokens)
+    if len(tokens) >= 3:
+        return shuffle_trailing(tokens)
+    return shuffle_all(tokens)
+
+
 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
     ids: list[int]
@@ -15,10 +58,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.token_map: dict[int, list[int]] = {}
-        self.vector_shuffle = False
+        self.vector_shuffle = shuffle_none
 
-    def set_use_vector_shuffle(self, enable: bool):
-        self.vector_shuffle = enable
+    def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
+        if algorithm == "leading":
+            self.vector_shuffle = shuffle_leading
+        elif algorithm == "trailing":
+            self.vector_shuffle = shuffle_trailing
+        elif algorithm == "between":
+            self.vector_shuffle = shuffle_between
+        elif algorithm == "auto":
+            self.vector_shuffle = shuffle_auto
+        elif algorithm == True or algorithm == "all":
+            self.vector_shuffle = shuffle_all
+        else:
+            self.vector_shuffle = shuffle_none
 
     def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
         if isinstance(new_tokens, list):
@@ -47,17 +101,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
     def expand_id(self, id: int):
-        if id in self.token_map:
-            tokens = self.token_map[id]
-
-            if self.vector_shuffle and len(tokens) > 2:
-                subtokens = tokens[1:-1]
-                np.random.shuffle(subtokens)
-                tokens = tokens[:1] + subtokens + tokens[-1:]
-
-            return tokens
-        else:
-            return [id]
+        return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id]
 
     def expand_ids(self, ids: list[int]):
         return [
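
For reference, a small sketch (not part of the commit) of how the new shuffle variants treat a multi-vector token's ids; it assumes the repository root is on the import path and uses made-up ids:

    from models.clip.tokenizer import (
        shuffle_all, shuffle_auto, shuffle_between, shuffle_leading, shuffle_trailing,
    )

    ids = [101, 102, 103, 104, 105]

    shuffle_all(ids)       # every position may move
    shuffle_leading(ids)   # last id stays fixed, the rest are shuffled
    shuffle_trailing(ids)  # first id stays fixed, the rest are shuffled
    shuffle_between(ids)   # first and last stay fixed, only the middle is shuffled
    shuffle_auto(ids)      # picks between/trailing/all depending on the list length

set_use_vector_shuffle() selects one of these by name, with anything unrecognized falling back to shuffle_none, and expand_id() now simply applies the selected function to a placeholder token's id list.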
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1ebcfe3..b07de31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -3,7 +3,6 @@ import itertools
 import math
 import datetime
 import logging
-import json
 from pathlib import Path
 
 import torch
@@ -15,18 +14,21 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.ti import patch_trainable_embeddings
+from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -106,6 +108,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
@@ -193,13 +201,12 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
-        type=float,
-        default=2e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate_text",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -546,9 +553,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -558,6 +565,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -576,46 +585,42 @@ def main():
         device=accelerator.device
     )
 
-    # Freeze text_encoder and vae
-    vae.requires_grad_(False)
+    embeddings = patch_managed_embeddings(text_encoder)
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = torch.stack([
-            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        initializer_token_ids = [
+            tokenizer.encode(token, add_special_tokens=False)
             for token in args.initializer_token
-        ])
-
-        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        print(f"Added {num_added_tokens} new tokens.")
+        ]
 
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        embeddings.resize(len(tokenizer))
 
-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        text_encoder.resize_token_embeddings(len(tokenizer))
+        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+            embeddings.add_embed(new_token.ids, init_ids)
 
-        token_embeds = text_encoder.get_input_embeddings().weight.data
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+        print(f"Added {len(new_tokens)} new tokens.")
     else:
         placeholder_token_id = []
 
+    vae.requires_grad_(False)
+
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
+
+        unpatch_managed_embeddings(text_encoder)
     else:
         print(f"Training added text embeddings")
 
-        patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
@@ -624,15 +629,14 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-        args.learning_rate_text = (
-            args.learning_rate_text * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -647,20 +651,19 @@ def main():
     if args.train_text_encoder:
         text_encoder_params_to_optimize = text_encoder.parameters()
     else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters()
+        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()
 
     # Initialize the optimizer
     optimizer = optimizer_class(
         [
             {
                 'params': unet.parameters(),
-                'lr': args.learning_rate_unet,
             },
             {
                 'params': text_encoder_params_to_optimize,
-                'lr': args.learning_rate_text,
             }
         ],
+        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
@@ -824,6 +827,58 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == target).float().mean()
+
+        return loss, acc, bsz
+
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
@@ -836,6 +891,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
@@ -893,58 +957,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def loop(batch):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -993,8 +1005,7 @@ def main():
                    "train/acc": avg_acc.avg.item(),
                    "train/cur_loss": loss.item(),
                    "train/cur_acc": acc.item(),
-                   "lr/unet": lr_scheduler.get_last_lr()[0],
-                   "lr/text": lr_scheduler.get_last_lr()[1]
+                   "lr": lr_scheduler.get_last_lr()[0]
                }
                if args.use_ema:
                    logs["ema_decay"] = 1 - ema_unet.decay
@@ -1011,12 +1022,21 @@ def main():
            unet.eval()
            text_encoder.eval()
 
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
            with torch.inference_mode():
                for step, batch in enumerate(val_dataloader):
                    loss, acc, bsz = loop(batch)
 
-                   avg_loss_val.update(loss.detach_(), bsz)
-                   avg_acc_val.update(acc.detach_(), bsz)
+                   loss = loss.detach_()
+                   acc = acc.detach_()
+
+                   cur_loss_val.update(loss, bsz)
+                   cur_acc_val.update(acc, bsz)
+
+                   avg_loss_val.update(loss, bsz)
+                   avg_acc_val.update(acc, bsz)
 
                    local_progress_bar.update(1)
                    global_progress_bar.update(1)
@@ -1029,20 +1049,20 @@ def main():
                    }
                    local_progress_bar.set_postfix(**logs)
 
-               logs["val/loss"] = avg_loss_val.avg.item()
-               logs["val/acc"] = avg_acc_val.avg.item()
-
-               accelerator.log({
-                   "val/loss": avg_loss_val.avg.item(),
-                   "val/acc": avg_acc_val.avg.item(),
-               }, step=global_step)
+               logs["val/cur_loss"] = cur_loss_val.avg.item()
+               logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+               accelerator.log(logs, step=global_step)
 
                local_progress_bar.clear()
                global_progress_bar.clear()
 
-               if avg_acc_val.avg.item() > max_acc_val:
-                   accelerator.print(
-                       f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                   max_acc_val = avg_acc_val.avg.item()
-
            if accelerator.is_main_process:
+               if avg_acc_val.avg.item() > max_acc_val:
+                   accelerator.print(
+                       f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                   max_acc_val = avg_acc_val.avg.item()
+
                if (epoch + 1) % args.sample_frequency == 0:
                    checkpointer.save_samples(global_step, args.sample_steps)
 
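To make the prior-preservation branch of the relocated loop() closure concrete, a standalone sketch (not part of the commit; tensor shapes and the loss weight are placeholder values, and the batch is assumed to hold instance samples followed by class samples):

    import torch
    import torch.nn.functional as F

    prior_loss_weight = 1.0                    # stands in for args.prior_loss_weight
    model_pred = torch.randn(4, 4, 64, 64)     # 2 instance + 2 class predictions
    target = torch.randn(4, 4, 64, 64)

    # Split along the batch dimension: first half instance, second half class.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    loss = loss + prior_loss_weight * prior_loss
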
diff --git a/train_ti.py b/train_ti.py
index 20a3190..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import datetime
 import logging
@@ -156,6 +155,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
+    parser.add_argument(
         "--dataloader_num_workers",
         type=int,
         default=0,
@@ -245,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=2,
+        default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
@@ -502,20 +507,14 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    if args.find_lr:
-        accelerator = Accelerator(
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
-    else:
-        accelerator = Accelerator(
-            log_with=LoggerType.TENSORBOARD,
-            logging_dir=f"{basepath}",
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
     args.seed = args.seed or (torch.random.seed() >> 32)
     set_seed(args.seed)
@@ -534,7 +533,7 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    tokenizer.set_use_vector_shuffle(True)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -585,7 +584,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e3
+        args.learning_rate = 1e2
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -830,15 +829,6 @@ def main():
 
         return loss, acc, bsz
 
-    if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-4)
-
-        plt.savefig(basepath.joinpath("lr.png"))
-        plt.close()
-
-        quit()
-
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
@@ -852,6 +842,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("textual_inversion", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
diff --git a/training/lr.py b/training/lr.py
index 0c5ce9e..3abd2f2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -102,6 +102,12 @@ class LRFinder():
             losses.append(loss)
             accs.append(acc)
 
+            self.accelerator.log({
+                "loss": loss,
+                "acc": acc,
+                "lr": lr,
+            }, step=epoch)
+
             progress_bar.set_postfix({
                 "loss": loss,
                 "loss/best": best_loss,
diff --git a/training/optimization.py b/training/optimization.py
index 3340544..a79944f 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -15,7 +15,7 @@ def get_one_cycle_schedule(
     warmup: Literal["cos", "linear"] = "cos",
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
-    annealing_exp: int = 2,
+    annealing_exp: int = 1,
     min_lr: int = 0.04,
     mid_point: int = 0.3,
     last_epoch: int = -1