diff options
-rw-r--r-- | data/csv.py | 2 | ||||
-rw-r--r-- | models/clip/tokenizer.py | 18 | ||||
-rw-r--r-- | train_dreambooth.py | 30 | ||||
-rw-r--r-- | train_ti.py | 36 | ||||
-rw-r--r-- | training/ti.py | 48 |
5 files changed, 74 insertions, 60 deletions
diff --git a/data/csv.py b/data/csv.py index 803271b..af36d9e 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -151,7 +151,7 @@ class CSVDataModule(): | |||
151 | 151 | ||
152 | num_images = len(items) | 152 | num_images = len(items) |
153 | 153 | ||
154 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1) | 154 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.2) |
155 | valid_set_size = max(valid_set_size, 1) | 155 | valid_set_size = max(valid_set_size, 1) |
156 | train_set_size = num_images - valid_set_size | 156 | train_set_size = num_images - valid_set_size |
157 | 157 | ||
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index a3e6e70..37d69a9 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -15,6 +15,10 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
15 | def __init__(self, *args, **kwargs): | 15 | def __init__(self, *args, **kwargs): |
16 | super().__init__(*args, **kwargs) | 16 | super().__init__(*args, **kwargs) |
17 | self.token_map: dict[int, list[int]] = {} | 17 | self.token_map: dict[int, list[int]] = {} |
18 | self.vector_shuffle = False | ||
19 | |||
20 | def set_use_vector_shuffle(self, enable: bool): | ||
21 | self.vector_shuffle = enable | ||
18 | 22 | ||
19 | def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: | 23 | def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: |
20 | if isinstance(new_tokens, list): | 24 | if isinstance(new_tokens, list): |
@@ -42,11 +46,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
42 | 46 | ||
43 | return MultiCLIPTokenizerItem(new_tokens, ids) | 47 | return MultiCLIPTokenizerItem(new_tokens, ids) |
44 | 48 | ||
45 | def expand_id(self, id: int, vector_shuffle=True): | 49 | def expand_id(self, id: int): |
46 | if id in self.token_map: | 50 | if id in self.token_map: |
47 | tokens = self.token_map[id] | 51 | tokens = self.token_map[id] |
48 | 52 | ||
49 | if vector_shuffle and len(tokens) > 2: | 53 | if self.vector_shuffle and len(tokens) > 2: |
50 | subtokens = tokens[1:-1] | 54 | subtokens = tokens[1:-1] |
51 | np.random.shuffle(subtokens) | 55 | np.random.shuffle(subtokens) |
52 | tokens = tokens[:1] + subtokens + tokens[-1:] | 56 | tokens = tokens[:1] + subtokens + tokens[-1:] |
@@ -55,21 +59,21 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
55 | else: | 59 | else: |
56 | return [id] | 60 | return [id] |
57 | 61 | ||
58 | def expand_ids(self, ids: list[int], vector_shuffle=True): | 62 | def expand_ids(self, ids: list[int]): |
59 | return [ | 63 | return [ |
60 | new_id | 64 | new_id |
61 | for id in ids | 65 | for id in ids |
62 | for new_id in self.expand_id(id, vector_shuffle) | 66 | for new_id in self.expand_id(id) |
63 | ] | 67 | ] |
64 | 68 | ||
65 | def _call_one(self, text, *args, vector_shuffle=True, **kwargs): | 69 | def _call_one(self, text, *args, **kwargs): |
66 | result = super()._call_one(text, *args, **kwargs) | 70 | result = super()._call_one(text, *args, **kwargs) |
67 | 71 | ||
68 | is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) | 72 | is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) |
69 | 73 | ||
70 | if is_batched: | 74 | if is_batched: |
71 | result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids] | 75 | result.input_ids = [self.expand_ids(batch) for batch in result.input_ids] |
72 | else: | 76 | else: |
73 | result.input_ids = self.expand_ids(result.input_ids, vector_shuffle) | 77 | result.input_ids = self.expand_ids(result.input_ids) |
74 | 78 | ||
75 | return result | 79 | return result |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 8fd78f1..1ebcfe3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -232,6 +232,30 @@ def parse_args(): | |||
232 | help="Number of restart cycles in the lr scheduler (if supported)." | 232 | help="Number of restart cycles in the lr scheduler (if supported)." |
233 | ) | 233 | ) |
234 | parser.add_argument( | 234 | parser.add_argument( |
235 | "--lr_warmup_func", | ||
236 | type=str, | ||
237 | default="cos", | ||
238 | help='Choose between ["linear", "cos"]' | ||
239 | ) | ||
240 | parser.add_argument( | ||
241 | "--lr_warmup_exp", | ||
242 | type=int, | ||
243 | default=1, | ||
244 | help='If lr_warmup_func is "cos", exponent to modify the function' | ||
245 | ) | ||
246 | parser.add_argument( | ||
247 | "--lr_annealing_func", | ||
248 | type=str, | ||
249 | default="cos", | ||
250 | help='Choose between ["linear", "half_cos", "cos"]' | ||
251 | ) | ||
252 | parser.add_argument( | ||
253 | "--lr_annealing_exp", | ||
254 | type=int, | ||
255 | default=3, | ||
256 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | ||
257 | ) | ||
258 | parser.add_argument( | ||
235 | "--use_ema", | 259 | "--use_ema", |
236 | action="store_true", | 260 | action="store_true", |
237 | default=True, | 261 | default=True, |
@@ -760,6 +784,10 @@ def main(): | |||
760 | lr_scheduler = get_one_cycle_schedule( | 784 | lr_scheduler = get_one_cycle_schedule( |
761 | optimizer=optimizer, | 785 | optimizer=optimizer, |
762 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 786 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
787 | warmup=args.lr_warmup_func, | ||
788 | annealing=args.lr_annealing_func, | ||
789 | warmup_exp=args.lr_warmup_exp, | ||
790 | annealing_exp=args.lr_annealing_exp, | ||
763 | ) | 791 | ) |
764 | elif args.lr_scheduler == "cosine_with_restarts": | 792 | elif args.lr_scheduler == "cosine_with_restarts": |
765 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 793 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
@@ -913,7 +941,7 @@ def main(): | |||
913 | else: | 941 | else: |
914 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 942 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
915 | 943 | ||
916 | acc = (model_pred == latents).float().mean() | 944 | acc = (model_pred == target).float().mean() |
917 | 945 | ||
918 | return loss, acc, bsz | 946 | return loss, acc, bsz |
919 | 947 | ||
diff --git a/train_ti.py b/train_ti.py index 19348e5..20a3190 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -225,6 +225,30 @@ def parse_args(): | |||
225 | help="Number of restart cycles in the lr scheduler." | 225 | help="Number of restart cycles in the lr scheduler." |
226 | ) | 226 | ) |
227 | parser.add_argument( | 227 | parser.add_argument( |
228 | "--lr_warmup_func", | ||
229 | type=str, | ||
230 | default="cos", | ||
231 | help='Choose between ["linear", "cos"]' | ||
232 | ) | ||
233 | parser.add_argument( | ||
234 | "--lr_warmup_exp", | ||
235 | type=int, | ||
236 | default=1, | ||
237 | help='If lr_warmup_func is "cos", exponent to modify the function' | ||
238 | ) | ||
239 | parser.add_argument( | ||
240 | "--lr_annealing_func", | ||
241 | type=str, | ||
242 | default="cos", | ||
243 | help='Choose between ["linear", "half_cos", "cos"]' | ||
244 | ) | ||
245 | parser.add_argument( | ||
246 | "--lr_annealing_exp", | ||
247 | type=int, | ||
248 | default=2, | ||
249 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | ||
250 | ) | ||
251 | parser.add_argument( | ||
228 | "--use_8bit_adam", | 252 | "--use_8bit_adam", |
229 | action="store_true", | 253 | action="store_true", |
230 | help="Whether or not to use 8-bit Adam from bitsandbytes." | 254 | help="Whether or not to use 8-bit Adam from bitsandbytes." |
@@ -510,6 +534,8 @@ def main(): | |||
510 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 534 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( |
511 | args.pretrained_model_name_or_path, subfolder='scheduler') | 535 | args.pretrained_model_name_or_path, subfolder='scheduler') |
512 | 536 | ||
537 | tokenizer.set_use_vector_shuffle(True) | ||
538 | |||
513 | vae.enable_slicing() | 539 | vae.enable_slicing() |
514 | vae.set_use_memory_efficient_attention_xformers(True) | 540 | vae.set_use_memory_efficient_attention_xformers(True) |
515 | unet.set_use_memory_efficient_attention_xformers(True) | 541 | unet.set_use_memory_efficient_attention_xformers(True) |
@@ -559,7 +585,7 @@ def main(): | |||
559 | ) | 585 | ) |
560 | 586 | ||
561 | if args.find_lr: | 587 | if args.find_lr: |
562 | args.learning_rate = 1e2 | 588 | args.learning_rate = 1e3 |
563 | 589 | ||
564 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 590 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
565 | if args.use_8bit_adam: | 591 | if args.use_8bit_adam: |
@@ -706,6 +732,10 @@ def main(): | |||
706 | lr_scheduler = get_one_cycle_schedule( | 732 | lr_scheduler = get_one_cycle_schedule( |
707 | optimizer=optimizer, | 733 | optimizer=optimizer, |
708 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 734 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
735 | warmup=args.lr_warmup_func, | ||
736 | annealing=args.lr_annealing_func, | ||
737 | warmup_exp=args.lr_warmup_exp, | ||
738 | annealing_exp=args.lr_annealing_exp, | ||
709 | ) | 739 | ) |
710 | elif args.lr_scheduler == "cosine_with_restarts": | 740 | elif args.lr_scheduler == "cosine_with_restarts": |
711 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 741 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
@@ -796,13 +826,13 @@ def main(): | |||
796 | else: | 826 | else: |
797 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 827 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
798 | 828 | ||
799 | acc = (model_pred == latents).float().mean() | 829 | acc = (model_pred == target).float().mean() |
800 | 830 | ||
801 | return loss, acc, bsz | 831 | return loss, acc, bsz |
802 | 832 | ||
803 | if args.find_lr: | 833 | if args.find_lr: |
804 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 834 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) |
805 | lr_finder.run(min_lr=1e-6, num_train_batches=1) | 835 | lr_finder.run(min_lr=1e-4) |
806 | 836 | ||
807 | plt.savefig(basepath.joinpath("lr.png")) | 837 | plt.savefig(basepath.joinpath("lr.png")) |
808 | plt.close() | 838 | plt.close() |
diff --git a/training/ti.py b/training/ti.py deleted file mode 100644 index 031fe48..0000000 --- a/training/ti.py +++ /dev/null | |||
@@ -1,48 +0,0 @@ | |||
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | import torch.nn as nn | ||
5 | |||
6 | from transformers.models.clip import CLIPTextModel, CLIPTextConfig | ||
7 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | ||
8 | |||
9 | |||
10 | def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]): | ||
11 | text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids) | ||
12 | text_encoder.text_model.embeddings = text_embeddings | ||
13 | |||
14 | |||
15 | class TrainableEmbeddings(CLIPTextEmbeddings): | ||
16 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]): | ||
17 | super().__init__(config) | ||
18 | |||
19 | self.token_embedding = embeddings.token_embedding | ||
20 | self.position_embedding = embeddings.position_embedding | ||
21 | |||
22 | self.train_indices = torch.tensor(new_ids) | ||
23 | |||
24 | self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) | ||
25 | self.trainable_embedding.weight.data.zero_() | ||
26 | self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices] | ||
27 | |||
28 | def forward( | ||
29 | self, | ||
30 | input_ids: Optional[torch.LongTensor] = None, | ||
31 | position_ids: Optional[torch.LongTensor] = None, | ||
32 | inputs_embeds: Optional[torch.FloatTensor] = None, | ||
33 | ) -> torch.Tensor: | ||
34 | device = input_ids.device | ||
35 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] | ||
36 | |||
37 | if position_ids is None: | ||
38 | position_ids = self.position_ids[:, :seq_length] | ||
39 | |||
40 | if inputs_embeds is None: | ||
41 | mask = torch.isin(input_ids, self.train_indices.to(device)) | ||
42 | inputs_embeds = self.token_embedding(input_ids) | ||
43 | inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask] | ||
44 | |||
45 | position_embeddings = self.position_embedding(position_ids) | ||
46 | embeddings = inputs_embeds + position_embeddings | ||
47 | |||
48 | return embeddings | ||