 models/clip/tokenizer.py | 27 +++++++++++++++++++++++++--
 train_dreambooth.py      | 24 +++++++++++++++++++-----
 train_ti.py              | 24 +++++++++++++++++++-----
 training/lr.py           |  9 ++++++---
 4 files changed, 69 insertions(+), 15 deletions(-)
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index bd0bd21..11a3df0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -6,6 +6,12 @@ import numpy as np
 from transformers import CLIPTokenizer
 
 
+def dropout(tokens: list[int], dropout: float):
+    if dropout != 0:
+        tokens = [token for token in tokens if np.random.random() > dropout]
+    return tokens
+
+
 def shuffle_all(tokens: list[int]):
     if len(tokens) >= 2:
         tokens = copy.copy(tokens)
@@ -59,7 +65,18 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         super().__init__(*args, **kwargs)
 
         self.token_map: dict[int, list[int]] = {}
-        self.vector_shuffle = shuffle_none
+        self.is_training = False
+        self.vector_shuffle = shuffle_auto
+        self.dropout = 0
+
+    def train(self):
+        self.is_training = True
+
+    def eval(self):
+        self.is_training = False
+
+    def set_dropout(self, dropout: float):
+        self.dropout = dropout
 
     def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
         if algorithm == "leading":
@@ -105,7 +122,13 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
     def expand_id(self, id: int):
-        return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id]
+        if id in self.token_map:
+            ids = self.token_map[id]
+            if self.is_training:
+                ids = dropout(self.vector_shuffle(ids), self.dropout)
+            return ids
+        else:
+            return [id]
 
     def expand_ids(self, ids: list[int]):
         return [
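A minimal sketch of the new training-mode behavior (the checkpoint path and token ids below are illustrative, not taken from this repo's configs):

    from models.clip.tokenizer import MultiCLIPTokenizer

    # Hypothetical checkpoint and placeholder ids, purely for illustration.
    tokenizer = MultiCLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    tokenizer.token_map[49408] = [49408, 49409, 49410]
    tokenizer.set_dropout(0.1)

    tokenizer.eval()
    print(tokenizer.expand_id(49408))  # deterministic: [49408, 49409, 49410]

    tokenizer.train()
    print(tokenizer.expand_id(49408))  # shuffled, each id kept with probability 0.9

In eval mode expand_id returns the mapped ids untouched, so validation and inference see stable embeddings; shuffle and dropout only perturb the vectors while is_training is set.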
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 218018b..f26b7f5 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -108,6 +108,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_dropout",
+        type=float,
+        default=0.1,
+        help="Vector dropout probability.",
+    )
+    parser.add_argument(
         "--vector_shuffle",
         type=str,
         default="auto",
@@ -556,6 +562,8 @@ def main():
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -826,6 +834,12 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def on_train():
+        tokenizer.train()
+
+    def on_eval():
+        tokenizer.eval()
+
     def loop(batch):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
@@ -898,8 +912,8 @@ def main():
            train_dataloader,
            val_dataloader,
            loop,
-           on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
-           on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+           on_train=tokenizer.train,
+           on_eval=tokenizer.eval,
        )
        lr_finder.run(end_lr=1e2)
 
@@ -953,7 +967,7 @@ def main():
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
-   local_progress_bar.set_description("Epoch X / Y")
+   local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
    global_progress_bar = tqdm(
        range(args.max_train_steps + val_steps),
@@ -976,7 +990,7 @@ def main():
            text_encoder.train()
        elif epoch == args.train_text_encoder_epochs:
            text_encoder.requires_grad_(False)
-       tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+       on_train()
 
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
@@ -1030,7 +1044,7 @@ def main():
 
        unet.eval()
        text_encoder.eval()
-       tokenizer.set_use_vector_shuffle(False)
+       on_eval()
 
        cur_loss_val = AverageMeter()
        cur_acc_val = AverageMeter()
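Condensed, the wiring this script gains reduces to the following pattern (a sketch, not the script itself; the checkpoint path is a placeholder and the values shown are the new defaults):

    from models.clip.tokenizer import MultiCLIPTokenizer

    tokenizer = MultiCLIPTokenizer.from_pretrained("path/to/model", subfolder="tokenizer")
    tokenizer.set_use_vector_shuffle("auto")  # args.vector_shuffle default
    tokenizer.set_dropout(0.1)                # args.vector_dropout default

    # Phase changes become a single mode switch instead of re-selecting
    # the shuffle algorithm at every transition:
    tokenizer.train()  # shuffle + dropout active for training steps
    tokenizer.eval()   # deterministic expansion for validation

This is also what lets the LRFinder hooks shrink to the bound methods tokenizer.train and tokenizer.eval instead of lambdas that re-set the shuffle algorithm.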
diff --git a/train_ti.py b/train_ti.py
index 102c0fa..cacbbc7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -155,6 +155,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_dropout",
+        type=float,
+        default=0.1,
+        help="Vector dropout probability.",
+    )
+    parser.add_argument(
         "--vector_shuffle",
         type=str,
         default="auto",
@@ -526,6 +532,8 @@ def main():
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -777,6 +785,12 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def on_train():
+        tokenizer.train()
+
+    def on_eval():
+        tokenizer.eval()
+
     def loop(batch):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
@@ -850,8 +864,8 @@ def main():
            train_dataloader,
            val_dataloader,
            loop,
-           on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
-           on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+           on_train=on_train,
+           on_eval=on_eval,
        )
        lr_finder.run(end_lr=1e2)
 
@@ -903,7 +917,7 @@ def main():
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
-   local_progress_bar.set_description("Epoch X / Y")
+   local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
    global_progress_bar = tqdm(
        range(args.max_train_steps + val_steps),
@@ -922,7 +936,7 @@ def main():
        local_progress_bar.reset()
 
        text_encoder.train()
-       tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+       on_train()
 
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(text_encoder):
@@ -963,7 +977,7 @@ def main():
        accelerator.wait_for_everyone()
 
        text_encoder.eval()
-       tokenizer.set_use_vector_shuffle(False)
+       on_eval()
 
        cur_loss_val = AverageMeter()
        cur_acc_val = AverageMeter()
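One detail worth noting about the --vector_dropout flag shared by both training scripts: argparse applies `type` to the raw command-line string, so a probability flag needs `type=float` (with `type=int`, passing `--vector_dropout 0.1` would raise a ValueError at parse time). A self-contained check, mirroring the flag definition above:

    import argparse

    # argparse calls `type` on the raw string; float("0.1") works,
    # int("0.1") would raise ValueError.
    parser = argparse.ArgumentParser()
    parser.add_argument("--vector_dropout", type=float, default=0.1,
                        help="Vector dropout probability.")

    args = parser.parse_args(["--vector_dropout", "0.25"])
    assert args.vector_dropout == 0.25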
diff --git a/training/lr.py b/training/lr.py
index acc01a2..37588b6 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -58,7 +58,11 @@ class LRFinder():
         losses = []
         accs = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(
+            self.optimizer,
+            end_lr,
+            num_epochs * min(num_train_batches, len(self.train_dataloader))
+        )
 
         steps = min(num_train_batches, len(self.train_dataloader))
         steps += min(num_val_batches, len(self.val_dataloader))
@@ -90,6 +94,7 @@ class LRFinder():
                 self.accelerator.backward(loss)
 
                 self.optimizer.step()
+                lr_scheduler.step()
                 self.optimizer.zero_grad(set_to_none=True)
 
                 if self.accelerator.sync_gradients:
@@ -109,8 +114,6 @@ class LRFinder():
 
                 progress_bar.update(1)
 
-            lr_scheduler.step()
-
             loss = avg_loss.avg.item()
             acc = avg_acc.avg.item()
 
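The LRFinder change moves lr_scheduler.step() from once per epoch into the batch loop, so the schedule's horizon has to grow from num_epochs to num_epochs times the batches per epoch for the sweep to still land on end_lr; that is exactly what the first hunk does. A toy stand-in illustrates the arithmetic (assuming get_exponential_schedule raises the learning rate by a constant factor per step toward end_lr; the repo's actual helper may differ):

    import torch

    # Toy stand-in for get_exponential_schedule (assumption): multiply the
    # lr by a constant factor each step so it reaches end_lr after total_steps.
    def exponential_schedule(optimizer, end_lr, total_steps):
        start_lr = optimizer.param_groups[0]["lr"]
        gamma = (end_lr / start_lr) ** (1 / total_steps)
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-6)
    num_epochs, batches_per_epoch = 100, 10

    scheduler = exponential_schedule(optimizer, end_lr=1e2,
                                     total_steps=num_epochs * batches_per_epoch)

    for _ in range(num_epochs * batches_per_epoch):
        optimizer.step()   # no-op here; ordering avoids a PyTorch warning
        scheduler.step()   # once per batch, matching the moved call above

    print(optimizer.param_groups[0]["lr"])  # ~1e2: the sweep ends at end_lr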
