 data/csv.py              |  2
 models/clip/tokenizer.py | 18
 train_dreambooth.py      | 30
 train_ti.py              | 36
 training/ti.py           | 48
 5 files changed, 74 insertions(+), 60 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 803271b..af36d9e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -151,7 +151,7 @@ class CSVDataModule():
 
         num_images = len(items)
 
-        valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1)
+        valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.2)
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index a3e6e70..37d69a9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -15,6 +15,10 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.token_map: dict[int, list[int]] = {}
+        self.vector_shuffle = False
+
+    def set_use_vector_shuffle(self, enable: bool):
+        self.vector_shuffle = enable
 
     def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
         if isinstance(new_tokens, list):
@@ -42,11 +46,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
 
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
-    def expand_id(self, id: int, vector_shuffle=True):
+    def expand_id(self, id: int):
         if id in self.token_map:
             tokens = self.token_map[id]
 
-            if vector_shuffle and len(tokens) > 2:
+            if self.vector_shuffle and len(tokens) > 2:
                 subtokens = tokens[1:-1]
                 np.random.shuffle(subtokens)
                 tokens = tokens[:1] + subtokens + tokens[-1:]
@@ -55,21 +59,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         else:
             return [id]
 
-    def expand_ids(self, ids: list[int], vector_shuffle=True):
+    def expand_ids(self, ids: list[int]):
         return [
             new_id
             for id in ids
-            for new_id in self.expand_id(id, vector_shuffle)
+            for new_id in self.expand_id(id)
         ]
 
-    def _call_one(self, text, *args, vector_shuffle=True, **kwargs):
+    def _call_one(self, text, *args, **kwargs):
         result = super()._call_one(text, *args, **kwargs)
 
         is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
 
         if is_batched:
-            result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
+            result.input_ids = [self.expand_ids(batch) for batch in result.input_ids]
         else:
-            result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)
+            result.input_ids = self.expand_ids(result.input_ids)
 
         return result
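
(Sketch: the refactor above turns vector shuffling into tokenizer state instead of a per-call flag. Roughly, usage now looks like the following; the checkpoint name, placeholder token, and vector count are illustrative assumptions, not taken from this diff.)

# Minimal sketch, assuming MultiCLIPTokenizer loads like its CLIPTokenizer base class.
from models.clip.tokenizer import MultiCLIPTokenizer

tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed checkpoint
tokenizer.add_multi_tokens("<token>", num_vectors=4)    # hypothetical placeholder token

tokenizer.set_use_vector_shuffle(True)                  # previously: tokenizer(text, vector_shuffle=True)
input_ids = tokenizer("a photo of <token>").input_ids   # inner vectors of <token> are shuffled on expansion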
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8fd78f1..1ebcfe3 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -232,6 +232,30 @@ def parse_args():
         help="Number of restart cycles in the lr scheduler (if supported)."
     )
     parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=3,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
         "--use_ema",
         action="store_true",
         default=True,
@@ -760,6 +784,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
@@ -913,7 +941,7 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
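
(The new --lr_warmup_* / --lr_annealing_* flags above are only consumed by the one_cycle scheduler branch. get_one_cycle_schedule itself is not part of this diff; purely as an assumption about what a "cos" warmup/annealing shape with an exponent might compute, a minimal sketch could be:)

import math

def warmup_cos(progress: float, exp: int = 1) -> float:
    # Hypothetical sketch, not the repo's implementation: cosine ramp from 0 to 1
    # over the warmup phase, sharpened by raising it to `exp`.
    return (0.5 * (1.0 - math.cos(math.pi * progress))) ** exp

def annealing_cos(progress: float, exp: int = 1) -> float:
    # Hypothetical sketch: cosine decay from 1 to 0 over the annealing phase.
    return (0.5 * (1.0 + math.cos(math.pi * progress))) ** exp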
diff --git a/train_ti.py b/train_ti.py
index 19348e5..20a3190 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -225,6 +225,30 @@ def parse_args():
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=2,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -510,6 +534,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(True)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -559,7 +585,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e3
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -706,6 +732,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
@@ -796,13 +826,13 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
     if args.find_lr:
         lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-6, num_train_batches=1)
+        lr_finder.run(min_lr=1e-4)
 
         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
diff --git a/training/ti.py b/training/ti.py
deleted file mode 100644
index 031fe48..0000000
--- a/training/ti.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-from transformers.models.clip import CLIPTextModel, CLIPTextConfig
-from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
-
-
-def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
-    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
-    text_encoder.text_model.embeddings = text_embeddings
-
-
-class TrainableEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
-        super().__init__(config)
-
-        self.token_embedding = embeddings.token_embedding
-        self.position_embedding = embeddings.position_embedding
-
-        self.train_indices = torch.tensor(new_ids)
-
-        self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
-        self.trainable_embedding.weight.data.zero_()
-        self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices]
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        device = input_ids.device
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            mask = torch.isin(input_ids, self.train_indices.to(device))
-            inputs_embeds = self.token_embedding(input_ids)
-            inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask]
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
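
(For context on the removal: the deleted helper routed embedding lookups for newly added placeholder token ids through a separate trainable embedding table, so only those vectors needed to be optimized. Its old call site is not shown in this diff; a hypothetical sketch of how it would have been wired up, with made-up token ids and an assumed checkpoint, is:)

# Hypothetical call site for the removed helper (ids and checkpoint are illustrative only).
from transformers import CLIPTextModel
from training.ti import patch_trainable_embeddings  # module deleted by this commit

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")  # assumed checkpoint
new_ids = [49408, 49409]  # ids returned when the placeholder tokens were added
patch_trainable_embeddings(text_encoder, new_ids)  # swaps in TrainableEmbeddings for text_model.embeddings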