author     Volpeon <git@volpeon.ink>    2023-01-03 12:40:16 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-03 12:40:16 +0100
commit     a72b6260c117cabe4fcb2996cce4f870986df99b (patch)
tree       7c9c7704c6ef60a4ab886d5acbce4e6e22398b56
parent     Fixed LR finder (diff)
Added vector dropout
-rw-r--r--  models/clip/tokenizer.py  27
-rw-r--r--  train_dreambooth.py       24
-rw-r--r--  train_ti.py               24
-rw-r--r--  training/lr.py             9
4 files changed, 69 insertions(+), 15 deletions(-)
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index bd0bd21..11a3df0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -6,6 +6,12 @@ import numpy as np
 from transformers import CLIPTokenizer
 
 
+def dropout(tokens: list[int], dropout: float):
+    if dropout != 0:
+        tokens = [token for token in tokens if np.random.random() > dropout]
+    return tokens
+
+
 def shuffle_all(tokens: list[int]):
     if len(tokens) >= 2:
         tokens = copy.copy(tokens)
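Each token survives dropout independently with probability 1 - dropout, so a multi-vector embedding can occasionally collapse to an empty expansion for a step. A quick standalone simulation (the function body is copied from the hunk above; the three-vector token is made up):

    import numpy as np

    def dropout(tokens: list[int], dropout: float):
        if dropout != 0:
            tokens = [token for token in tokens if np.random.random() > dropout]
        return tokens

    np.random.seed(0)
    runs = [dropout([1, 2, 3], 0.1) for _ in range(10_000)]
    print(sum(len(r) for r in runs) / len(runs))  # ~2.7: each vector kept 90% of the time
    print(sum(len(r) == 0 for r in runs))         # ~10: all three dropped, prob 0.1**3
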
@@ -59,7 +65,18 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         super().__init__(*args, **kwargs)
 
         self.token_map: dict[int, list[int]] = {}
-        self.vector_shuffle = shuffle_none
+        self.is_training = False
+        self.vector_shuffle = shuffle_auto
+        self.dropout = 0
+
+    def train(self):
+        self.is_training = True
+
+    def eval(self):
+        self.is_training = False
+
+    def set_dropout(self, dropout: float):
+        self.dropout = dropout
 
     def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
         if algorithm == "leading":
@@ -105,7 +122,13 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
     def expand_id(self, id: int):
-        return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id]
+        if id in self.token_map:
+            ids = self.token_map[id]
+            if self.is_training:
+                ids = dropout(self.vector_shuffle(ids), self.dropout)
+            return ids
+        else:
+            return [id]
 
     def expand_ids(self, ids: list[int]):
         return [
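Taken together, expand_id now shuffles and then drops the vectors of a multi-vector token during training, and returns them untouched in eval mode. A minimal standalone sketch (shuffling elided; the ids in token_map are hypothetical):

    import numpy as np

    def dropout(tokens: list[int], p: float):
        if p != 0:
            tokens = [t for t in tokens if np.random.random() > p]
        return tokens

    token_map = {49408: [49408, 49409, 49410]}  # one placeholder id -> three vectors

    def expand_id(id: int, is_training: bool, p: float = 0.1):
        # Mirrors MultiCLIPTokenizer.expand_id, minus the shuffle step.
        if id in token_map:
            ids = token_map[id]
            if is_training:
                ids = dropout(ids, p)
            return ids
        return [id]

    np.random.seed(0)
    print(expand_id(49408, is_training=True))   # some of the three ids may be missing
    print(expand_id(49408, is_training=False))  # always [49408, 49409, 49410]
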
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 218018b..f26b7f5 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -108,6 +108,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_dropout",
+        type=float,
+        default=0.1,
+        help="Vector dropout probability.",
+    )
+    parser.add_argument(
         "--vector_shuffle",
         type=str,
         default="auto",
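Since --vector_dropout is a probability, it must be parsed with type=float; with type=int, argparse would reject any fractional value outright ("invalid int value: '0.2'") while silently keeping the float default. A quick check (flag name from the diff; the rest is illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--vector_dropout", type=float, default=0.1,
                        help="Vector dropout probability.")

    print(parser.parse_args(["--vector_dropout", "0.2"]).vector_dropout)  # 0.2
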
@@ -556,6 +562,8 @@ def main():
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -826,6 +834,12 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def on_train():
+        tokenizer.train()
+
+    def on_eval():
+        tokenizer.eval()
+
     def loop(batch):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
@@ -898,8 +912,8 @@
             train_dataloader,
             val_dataloader,
             loop,
-            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
-            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+            on_train=tokenizer.train,
+            on_eval=tokenizer.eval,
         )
         lr_finder.run(end_lr=1e2)
 
@@ -953,7 +967,7 @@
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Epoch X / Y")
+    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -976,7 +990,7 @@
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
             text_encoder.requires_grad_(False)
-        tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+        on_train()
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
@@ -1030,7 +1044,7 @@
 
         unet.eval()
         text_encoder.eval()
-        tokenizer.set_use_vector_shuffle(False)
+        on_eval()
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
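Because train() and eval() take no arguments, the bound methods can be handed to LRFinder directly in place of the earlier lambdas. A minimal illustration of the callback shape (LRFinder internals elided; MiniTokenizer is a stand-in for the real tokenizer):

    class MiniTokenizer:
        def __init__(self):
            self.is_training = False

        def train(self):
            self.is_training = True

        def eval(self):
            self.is_training = False

    def run_phases(on_train, on_eval):
        on_train()  # before the training batches: shuffle + dropout active
        on_eval()   # before the validation batches: deterministic expansion

    tok = MiniTokenizer()
    run_phases(on_train=tok.train, on_eval=tok.eval)
    print(tok.is_training)  # False: eval() ran last
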
diff --git a/train_ti.py b/train_ti.py
index 102c0fa..cacbbc7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -155,6 +155,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_dropout",
+        type=float,
+        default=0.1,
+        help="Vector dropout probability.",
+    )
+    parser.add_argument(
         "--vector_shuffle",
         type=str,
         default="auto",
@@ -526,6 +532,8 @@ def main():
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -777,6 +785,12 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def on_train():
+        tokenizer.train()
+
+    def on_eval():
+        tokenizer.eval()
+
     def loop(batch):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
@@ -850,8 +864,8 @@
             train_dataloader,
             val_dataloader,
             loop,
-            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
-            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+            on_train=on_train,
+            on_eval=on_eval,
         )
         lr_finder.run(end_lr=1e2)
 
@@ -903,7 +917,7 @@
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Epoch X / Y")
+    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -922,7 +936,7 @@
         local_progress_bar.reset()
 
         text_encoder.train()
-        tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+        on_train()
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
@@ -963,7 +977,7 @@
         accelerator.wait_for_everyone()
 
         text_encoder.eval()
-        tokenizer.set_use_vector_shuffle(False)
+        on_eval()
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
diff --git a/training/lr.py b/training/lr.py
index acc01a2..37588b6 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -58,7 +58,11 @@ class LRFinder():
         losses = []
         accs = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(
+            self.optimizer,
+            end_lr,
+            num_epochs * min(num_train_batches, len(self.train_dataloader))
+        )
 
         steps = min(num_train_batches, len(self.train_dataloader))
        steps += min(num_val_batches, len(self.val_dataloader))
@@ -90,6 +94,7 @@
                 self.accelerator.backward(loss)
 
                 self.optimizer.step()
+                lr_scheduler.step()
                 self.optimizer.zero_grad(set_to_none=True)
 
                 if self.accelerator.sync_gradients:
@@ -109,8 +114,6 @@
 
                 progress_bar.update(1)
 
-            lr_scheduler.step()
-
             loss = avg_loss.avg.item()
             acc = avg_acc.avg.item()
 
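With the scheduler now stepped once per optimizer step instead of once per epoch, its horizon has to be num_epochs * steps-per-epoch, and the LR grows by a constant factor (end_lr / start_lr) ** (1 / total_steps) each step. A sketch of what get_exponential_schedule plausibly does, assuming it wraps torch.optim.lr_scheduler.LambdaLR (the repo's actual helper may differ):

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    def get_exponential_schedule(optimizer, end_lr, num_steps):
        # Hypothetical reimplementation: multiply the initial LR by
        # (end_lr / base_lr) ** (step / num_steps), reaching end_lr
        # exactly after num_steps scheduler steps.
        base_lr = optimizer.defaults["lr"]
        return LambdaLR(optimizer, lambda step: (end_lr / base_lr) ** (step / num_steps))

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-6)
    sched = get_exponential_schedule(opt, end_lr=1e2, num_steps=100)
    for _ in range(100):
        opt.step()
        sched.step()
    print(opt.param_groups[0]["lr"])  # ~1e2 after the full sweep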