diff options
-rw-r--r-- | models/clip/tokenizer.py | 27 | ||||
-rw-r--r-- | train_dreambooth.py | 24 | ||||
-rw-r--r-- | train_ti.py | 24 | ||||
-rw-r--r-- | training/lr.py | 9 |
4 files changed, 69 insertions, 15 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index bd0bd21..11a3df0 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -6,6 +6,12 @@ import numpy as np | |||
6 | from transformers import CLIPTokenizer | 6 | from transformers import CLIPTokenizer |
7 | 7 | ||
8 | 8 | ||
9 | def dropout(tokens: list[int], dropout: float): | ||
10 | if dropout != 0: | ||
11 | tokens = [token for token in tokens if np.random.random() > dropout] | ||
12 | return tokens | ||
13 | |||
14 | |||
9 | def shuffle_all(tokens: list[int]): | 15 | def shuffle_all(tokens: list[int]): |
10 | if len(tokens) >= 2: | 16 | if len(tokens) >= 2: |
11 | tokens = copy.copy(tokens) | 17 | tokens = copy.copy(tokens) |
@@ -59,7 +65,18 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
59 | super().__init__(*args, **kwargs) | 65 | super().__init__(*args, **kwargs) |
60 | 66 | ||
61 | self.token_map: dict[int, list[int]] = {} | 67 | self.token_map: dict[int, list[int]] = {} |
62 | self.vector_shuffle = shuffle_none | 68 | self.is_training = False |
69 | self.vector_shuffle = shuffle_auto | ||
70 | self.dropout = 0 | ||
71 | |||
72 | def train(self): | ||
73 | self.is_training = True | ||
74 | |||
75 | def eval(self): | ||
76 | self.is_training = False | ||
77 | |||
78 | def set_dropout(self, dropout: float): | ||
79 | self.dropout = dropout | ||
63 | 80 | ||
64 | def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): | 81 | def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): |
65 | if algorithm == "leading": | 82 | if algorithm == "leading": |
@@ -105,7 +122,13 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
105 | return MultiCLIPTokenizerItem(new_tokens, ids) | 122 | return MultiCLIPTokenizerItem(new_tokens, ids) |
106 | 123 | ||
107 | def expand_id(self, id: int): | 124 | def expand_id(self, id: int): |
108 | return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id] | 125 | if id in self.token_map: |
126 | ids = self.token_map[id] | ||
127 | if self.is_training: | ||
128 | ids = dropout(self.vector_shuffle(ids), self.dropout) | ||
129 | return ids | ||
130 | else: | ||
131 | return [id] | ||
109 | 132 | ||
110 | def expand_ids(self, ids: list[int]): | 133 | def expand_ids(self, ids: list[int]): |
111 | return [ | 134 | return [ |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 218018b..f26b7f5 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -108,6 +108,12 @@ def parse_args(): | |||
108 | help="Tag dropout probability.", | 108 | help="Tag dropout probability.", |
109 | ) | 109 | ) |
110 | parser.add_argument( | 110 | parser.add_argument( |
111 | "--vector_dropout", | ||
112 | type=int, | ||
113 | default=0.1, | ||
114 | help="Vector dropout probability.", | ||
115 | ) | ||
116 | parser.add_argument( | ||
111 | "--vector_shuffle", | 117 | "--vector_shuffle", |
112 | type=str, | 118 | type=str, |
113 | default="auto", | 119 | default="auto", |
@@ -556,6 +562,8 @@ def main(): | |||
556 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) | 562 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) |
557 | elif args.pretrained_model_name_or_path: | 563 | elif args.pretrained_model_name_or_path: |
558 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') | 564 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') |
565 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
566 | tokenizer.set_dropout(args.vector_dropout) | ||
559 | 567 | ||
560 | # Load models and create wrapper for stable diffusion | 568 | # Load models and create wrapper for stable diffusion |
561 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') | 569 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') |
@@ -826,6 +834,12 @@ def main(): | |||
826 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 834 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
827 | val_steps = num_val_steps_per_epoch * num_epochs | 835 | val_steps = num_val_steps_per_epoch * num_epochs |
828 | 836 | ||
837 | def on_train(): | ||
838 | tokenizer.train() | ||
839 | |||
840 | def on_eval(): | ||
841 | tokenizer.eval() | ||
842 | |||
829 | def loop(batch): | 843 | def loop(batch): |
830 | # Convert images to latent space | 844 | # Convert images to latent space |
831 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 845 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
@@ -898,8 +912,8 @@ def main(): | |||
898 | train_dataloader, | 912 | train_dataloader, |
899 | val_dataloader, | 913 | val_dataloader, |
900 | loop, | 914 | loop, |
901 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | 915 | on_train=tokenizer.train, |
902 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | 916 | on_eval=tokenizer.eval, |
903 | ) | 917 | ) |
904 | lr_finder.run(end_lr=1e2) | 918 | lr_finder.run(end_lr=1e2) |
905 | 919 | ||
@@ -953,7 +967,7 @@ def main(): | |||
953 | disable=not accelerator.is_local_main_process, | 967 | disable=not accelerator.is_local_main_process, |
954 | dynamic_ncols=True | 968 | dynamic_ncols=True |
955 | ) | 969 | ) |
956 | local_progress_bar.set_description("Epoch X / Y") | 970 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") |
957 | 971 | ||
958 | global_progress_bar = tqdm( | 972 | global_progress_bar = tqdm( |
959 | range(args.max_train_steps + val_steps), | 973 | range(args.max_train_steps + val_steps), |
@@ -976,7 +990,7 @@ def main(): | |||
976 | text_encoder.train() | 990 | text_encoder.train() |
977 | elif epoch == args.train_text_encoder_epochs: | 991 | elif epoch == args.train_text_encoder_epochs: |
978 | text_encoder.requires_grad_(False) | 992 | text_encoder.requires_grad_(False) |
979 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 993 | on_train() |
980 | 994 | ||
981 | for step, batch in enumerate(train_dataloader): | 995 | for step, batch in enumerate(train_dataloader): |
982 | with accelerator.accumulate(unet): | 996 | with accelerator.accumulate(unet): |
@@ -1030,7 +1044,7 @@ def main(): | |||
1030 | 1044 | ||
1031 | unet.eval() | 1045 | unet.eval() |
1032 | text_encoder.eval() | 1046 | text_encoder.eval() |
1033 | tokenizer.set_use_vector_shuffle(False) | 1047 | on_eval() |
1034 | 1048 | ||
1035 | cur_loss_val = AverageMeter() | 1049 | cur_loss_val = AverageMeter() |
1036 | cur_acc_val = AverageMeter() | 1050 | cur_acc_val = AverageMeter() |
diff --git a/train_ti.py b/train_ti.py index 102c0fa..cacbbc7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -155,6 +155,12 @@ def parse_args(): | |||
155 | help="Tag dropout probability.", | 155 | help="Tag dropout probability.", |
156 | ) | 156 | ) |
157 | parser.add_argument( | 157 | parser.add_argument( |
158 | "--vector_dropout", | ||
159 | type=int, | ||
160 | default=0.1, | ||
161 | help="Vector dropout probability.", | ||
162 | ) | ||
163 | parser.add_argument( | ||
158 | "--vector_shuffle", | 164 | "--vector_shuffle", |
159 | type=str, | 165 | type=str, |
160 | default="auto", | 166 | default="auto", |
@@ -526,6 +532,8 @@ def main(): | |||
526 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) | 532 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) |
527 | elif args.pretrained_model_name_or_path: | 533 | elif args.pretrained_model_name_or_path: |
528 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') | 534 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') |
535 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
536 | tokenizer.set_dropout(args.vector_dropout) | ||
529 | 537 | ||
530 | # Load models and create wrapper for stable diffusion | 538 | # Load models and create wrapper for stable diffusion |
531 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') | 539 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') |
@@ -777,6 +785,12 @@ def main(): | |||
777 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 785 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
778 | val_steps = num_val_steps_per_epoch * num_epochs | 786 | val_steps = num_val_steps_per_epoch * num_epochs |
779 | 787 | ||
788 | def on_train(): | ||
789 | tokenizer.train() | ||
790 | |||
791 | def on_eval(): | ||
792 | tokenizer.eval() | ||
793 | |||
780 | def loop(batch): | 794 | def loop(batch): |
781 | # Convert images to latent space | 795 | # Convert images to latent space |
782 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 796 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
@@ -850,8 +864,8 @@ def main(): | |||
850 | train_dataloader, | 864 | train_dataloader, |
851 | val_dataloader, | 865 | val_dataloader, |
852 | loop, | 866 | loop, |
853 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | 867 | on_train=on_train, |
854 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | 868 | on_eval=on_eval, |
855 | ) | 869 | ) |
856 | lr_finder.run(end_lr=1e2) | 870 | lr_finder.run(end_lr=1e2) |
857 | 871 | ||
@@ -903,7 +917,7 @@ def main(): | |||
903 | disable=not accelerator.is_local_main_process, | 917 | disable=not accelerator.is_local_main_process, |
904 | dynamic_ncols=True | 918 | dynamic_ncols=True |
905 | ) | 919 | ) |
906 | local_progress_bar.set_description("Epoch X / Y") | 920 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") |
907 | 921 | ||
908 | global_progress_bar = tqdm( | 922 | global_progress_bar = tqdm( |
909 | range(args.max_train_steps + val_steps), | 923 | range(args.max_train_steps + val_steps), |
@@ -922,7 +936,7 @@ def main(): | |||
922 | local_progress_bar.reset() | 936 | local_progress_bar.reset() |
923 | 937 | ||
924 | text_encoder.train() | 938 | text_encoder.train() |
925 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 939 | on_train() |
926 | 940 | ||
927 | for step, batch in enumerate(train_dataloader): | 941 | for step, batch in enumerate(train_dataloader): |
928 | with accelerator.accumulate(text_encoder): | 942 | with accelerator.accumulate(text_encoder): |
@@ -963,7 +977,7 @@ def main(): | |||
963 | accelerator.wait_for_everyone() | 977 | accelerator.wait_for_everyone() |
964 | 978 | ||
965 | text_encoder.eval() | 979 | text_encoder.eval() |
966 | tokenizer.set_use_vector_shuffle(False) | 980 | on_eval() |
967 | 981 | ||
968 | cur_loss_val = AverageMeter() | 982 | cur_loss_val = AverageMeter() |
969 | cur_acc_val = AverageMeter() | 983 | cur_acc_val = AverageMeter() |
diff --git a/training/lr.py b/training/lr.py index acc01a2..37588b6 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -58,7 +58,11 @@ class LRFinder(): | |||
58 | losses = [] | 58 | losses = [] |
59 | accs = [] | 59 | accs = [] |
60 | 60 | ||
61 | lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs) | 61 | lr_scheduler = get_exponential_schedule( |
62 | self.optimizer, | ||
63 | end_lr, | ||
64 | num_epochs * min(num_train_batches, len(self.train_dataloader)) | ||
65 | ) | ||
62 | 66 | ||
63 | steps = min(num_train_batches, len(self.train_dataloader)) | 67 | steps = min(num_train_batches, len(self.train_dataloader)) |
64 | steps += min(num_val_batches, len(self.val_dataloader)) | 68 | steps += min(num_val_batches, len(self.val_dataloader)) |
@@ -90,6 +94,7 @@ class LRFinder(): | |||
90 | self.accelerator.backward(loss) | 94 | self.accelerator.backward(loss) |
91 | 95 | ||
92 | self.optimizer.step() | 96 | self.optimizer.step() |
97 | lr_scheduler.step() | ||
93 | self.optimizer.zero_grad(set_to_none=True) | 98 | self.optimizer.zero_grad(set_to_none=True) |
94 | 99 | ||
95 | if self.accelerator.sync_gradients: | 100 | if self.accelerator.sync_gradients: |
@@ -109,8 +114,6 @@ class LRFinder(): | |||
109 | 114 | ||
110 | progress_bar.update(1) | 115 | progress_bar.update(1) |
111 | 116 | ||
112 | lr_scheduler.step() | ||
113 | |||
114 | loss = avg_loss.avg.item() | 117 | loss = avg_loss.avg.item() |
115 | acc = avg_acc.avg.item() | 118 | acc = avg_acc.avg.item() |
116 | 119 | ||