-rw-r--r--   data/csv.py              |  2
-rw-r--r--   models/clip/tokenizer.py | 18
-rw-r--r--   train_dreambooth.py      | 30
-rw-r--r--   train_ti.py              | 36
-rw-r--r--   training/ti.py           | 48
5 files changed, 74 insertions, 60 deletions
diff --git a/data/csv.py b/data/csv.py
index 803271b..af36d9e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -151,7 +151,7 @@ class CSVDataModule():
 
         num_images = len(items)
 
-        valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1)
+        valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.2)
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
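Note: with the new default, a dataset that does not set an explicit valid_set_size now reserves 20% of its images for validation instead of 10%, still at least one image. A quick check of the arithmetic (example numbers only):

num_images = 12
valid_set_size = max(int(num_images * 0.2), 1)  # 2 validation images (previously int(12 * 0.1) = 1)
train_set_size = num_images - valid_set_size    # 10 training images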
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index a3e6e70..37d69a9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -15,6 +15,10 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.token_map: dict[int, list[int]] = {}
+        self.vector_shuffle = False
+
+    def set_use_vector_shuffle(self, enable: bool):
+        self.vector_shuffle = enable
 
     def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
         if isinstance(new_tokens, list):
@@ -42,11 +46,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
 
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
-    def expand_id(self, id: int, vector_shuffle=True):
+    def expand_id(self, id: int):
         if id in self.token_map:
             tokens = self.token_map[id]
 
-            if vector_shuffle and len(tokens) > 2:
+            if self.vector_shuffle and len(tokens) > 2:
                 subtokens = tokens[1:-1]
                 np.random.shuffle(subtokens)
                 tokens = tokens[:1] + subtokens + tokens[-1:]
@@ -55,21 +59,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         else:
             return [id]
 
-    def expand_ids(self, ids: list[int], vector_shuffle=True):
+    def expand_ids(self, ids: list[int]):
         return [
             new_id
             for id in ids
-            for new_id in self.expand_id(id, vector_shuffle)
+            for new_id in self.expand_id(id)
         ]
 
-    def _call_one(self, text, *args, vector_shuffle=True, **kwargs):
+    def _call_one(self, text, *args, **kwargs):
         result = super()._call_one(text, *args, **kwargs)
 
         is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
 
         if is_batched:
-            result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
+            result.input_ids = [self.expand_ids(batch) for batch in result.input_ids]
         else:
-            result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)
+            result.input_ids = self.expand_ids(result.input_ids)
 
         return result
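Note: vector shuffling changes from a per-call keyword argument into tokenizer-wide state toggled via set_use_vector_shuffle(). A minimal standalone sketch of the expansion behaviour (placeholder token ids, not the repository's class):

import numpy as np

token_map = {49408: [49408, 49409, 49410, 49411]}  # hypothetical multi-vector token

def expand_id(id: int, vector_shuffle: bool) -> list[int]:
    if id not in token_map:
        return [id]
    tokens = token_map[id]
    if vector_shuffle and len(tokens) > 2:
        subtokens = tokens[1:-1]           # first and last vectors stay in place
        np.random.shuffle(subtokens)       # only the inner vectors are shuffled
        tokens = tokens[:1] + subtokens + tokens[-1:]
    return tokens

print(expand_id(49408, vector_shuffle=False))  # [49408, 49409, 49410, 49411]
print(expand_id(49408, vector_shuffle=True))   # e.g. [49408, 49410, 49409, 49411]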
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8fd78f1..1ebcfe3 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -232,6 +232,30 @@ def parse_args():
         help="Number of restart cycles in the lr scheduler (if supported)."
     )
     parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=3,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
         "--use_ema",
         action="store_true",
         default=True,
@@ -760,6 +784,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
@@ -913,7 +941,7 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
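Note: get_one_cycle_schedule itself is defined elsewhere in the repository and is not part of this diff; the sketch below is only a hypothetical illustration of how a one-cycle LambdaLR scheduler could honour the new warmup/annealing function and exponent parameters.

import math
from torch.optim.lr_scheduler import LambdaLR

def one_cycle_schedule(optimizer, num_training_steps, warmup="cos", annealing="cos",
                       warmup_exp=1, annealing_exp=3, warmup_fraction=0.3):
    # warmup_fraction is an assumed knob; the real scheduler may split the cycle differently
    warmup_steps = int(num_training_steps * warmup_fraction)

    def lr_lambda(step):
        if step < warmup_steps:
            p = step / max(1, warmup_steps)
            if warmup == "linear":
                return p
            return ((1 - math.cos(math.pi * p)) / 2) ** warmup_exp       # "cos": ramps 0 -> 1
        p = (step - warmup_steps) / max(1, num_training_steps - warmup_steps)
        if annealing == "linear":
            return 1 - p
        if annealing == "half_cos":
            return math.cos(0.5 * math.pi * p) ** annealing_exp          # quarter cosine, 1 -> 0
        return ((1 + math.cos(math.pi * p)) / 2) ** annealing_exp        # "cos": 1 -> 0

    return LambdaLR(optimizer, lr_lambda)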
diff --git a/train_ti.py b/train_ti.py
index 19348e5..20a3190 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -225,6 +225,30 @@ def parse_args():
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=2,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -510,6 +534,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(True)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -559,7 +585,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e3
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -706,6 +732,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
@@ -796,13 +826,13 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
     if args.find_lr:
         lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-4)
+        lr_finder.run(min_lr=1e-4)
 
         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
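Note: the LR finder now sweeps from min_lr=1e-4 up to the temporarily overridden learning rate of 1e3, i.e. seven orders of magnitude instead of eight. A hypothetical log-spaced sweep over that range (the repository's LRFinder may step differently):

import numpy as np

min_lr, max_lr, num_trials = 1e-4, 1e3, 100
lrs = np.geomspace(min_lr, max_lr, num_trials)  # logarithmically spaced trial learning rates
print(lrs[0], lrs[-1])                          # 1e-04 ... 1e+03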
diff --git a/training/ti.py b/training/ti.py
deleted file mode 100644
index 031fe48..0000000
--- a/training/ti.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-from transformers.models.clip import CLIPTextModel, CLIPTextConfig
-from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
-
-
-def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
-    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
-    text_encoder.text_model.embeddings = text_embeddings
-
-
-class TrainableEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
-        super().__init__(config)
-
-        self.token_embedding = embeddings.token_embedding
-        self.position_embedding = embeddings.position_embedding
-
-        self.train_indices = torch.tensor(new_ids)
-
-        self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
-        self.trainable_embedding.weight.data.zero_()
-        self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices]
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        device = input_ids.device
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            mask = torch.isin(input_ids, self.train_indices.to(device))
-            inputs_embeds = self.token_embedding(input_ids)
-            inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask]
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
