summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py10
-rw-r--r--train_dreambooth.py16
-rw-r--r--train_lora.py16
-rw-r--r--train_ti.py25
-rw-r--r--training/functional.py3
-rw-r--r--training/optimization.py3
6 files changed, 67 insertions, 6 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 4166dc6..9abd1bb 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -52,7 +52,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
52 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) 52 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
54 54
55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 55 def add_embed(
56 self,
57 token_ids: Union[int, list[int]],
58 initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None,
59 initializer_noise: float = 0.0,
60 ):
56 if isinstance(token_ids, int): 61 if isinstance(token_ids, int):
57 token_ids = [token_ids] 62 token_ids = [token_ids]
58 63
@@ -73,6 +78,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
73 dtype=self.temp_token_embedding.weight.dtype, 78 dtype=self.temp_token_embedding.weight.dtype,
74 ) 79 )
75 80
81 if initializer_noise != 0:
82 initializer += torch.randn_like(initializer) * initializer_noise
83
76 token_ids = torch.tensor(token_ids, dtype=torch.long) 84 token_ids = torch.tensor(token_ids, dtype=torch.long)
77 85
78 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 86 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 3a25efa..4456bd1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -13,6 +13,7 @@ from accelerate import Accelerator
13from accelerate.logging import get_logger 13from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 14from accelerate.utils import LoggerType, set_seed
15from slugify import slugify 15from slugify import slugify
16import transformers
16 17
17from util.files import load_config, load_embeddings_from_dir 18from util.files import load_config, load_embeddings_from_dir
18from data.csv import VlpnDataModule, keyword_filter 19from data.csv import VlpnDataModule, keyword_filter
@@ -305,7 +306,7 @@ def parse_args():
305 "--optimizer", 306 "--optimizer",
306 type=str, 307 type=str,
307 default="dadan", 308 default="dadan",
308 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' 309 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
309 ) 310 )
310 parser.add_argument( 311 parser.add_argument(
311 "--dadaptation_d0", 312 "--dadaptation_d0",
@@ -535,6 +536,19 @@ def main():
535 eps=args.adam_epsilon, 536 eps=args.adam_epsilon,
536 amsgrad=args.adam_amsgrad, 537 amsgrad=args.adam_amsgrad,
537 ) 538 )
539 elif args.optimizer == 'adafactor':
540 create_optimizer = partial(
541 transformers.optimization.Adafactor,
542 beta1=args.adam_beta1,
543 weight_decay=args.adam_weight_decay,
544 scale_parameter=True,
545 relative_step=True,
546 warmup_init=True,
547 )
548
549 args.lr_scheduler = "adafactor"
550 args.lr_min_lr = args.learning_rate
551 args.learning_rate = None
538 elif args.optimizer == 'dadam': 552 elif args.optimizer == 'dadam':
539 try: 553 try:
540 import dadaptation 554 import dadaptation
diff --git a/train_lora.py b/train_lora.py
index f74a438..f8dccae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 14from accelerate.utils import LoggerType, set_seed
15from peft import LoraConfig, LoraModel 15from peft import LoraConfig, LoraModel
16from slugify import slugify 16from slugify import slugify
17import transformers
17 18
18from util.files import load_config, load_embeddings_from_dir 19from util.files import load_config, load_embeddings_from_dir
19from data.csv import VlpnDataModule, keyword_filter 20from data.csv import VlpnDataModule, keyword_filter
@@ -317,7 +318,7 @@ def parse_args():
317 "--optimizer", 318 "--optimizer",
318 type=str, 319 type=str,
319 default="dadan", 320 default="dadan",
320 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' 321 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
321 ) 322 )
322 parser.add_argument( 323 parser.add_argument(
323 "--dadaptation_d0", 324 "--dadaptation_d0",
@@ -567,6 +568,19 @@ def main():
567 eps=args.adam_epsilon, 568 eps=args.adam_epsilon,
568 amsgrad=args.adam_amsgrad, 569 amsgrad=args.adam_amsgrad,
569 ) 570 )
571 elif args.optimizer == 'adafactor':
572 create_optimizer = partial(
573 transformers.optimization.Adafactor,
574 beta1=args.adam_beta1,
575 weight_decay=args.adam_weight_decay,
576 scale_parameter=True,
577 relative_step=True,
578 warmup_init=True,
579 )
580
581 args.lr_scheduler = "adafactor"
582 args.lr_min_lr = args.learning_rate
583 args.learning_rate = None
570 elif args.optimizer == 'dadam': 584 elif args.optimizer == 'dadam':
571 try: 585 try:
572 import dadaptation 586 import dadaptation
diff --git a/train_ti.py b/train_ti.py
index dd015f9..274a1ca 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,6 +12,7 @@ from accelerate import Accelerator
12from accelerate.logging import get_logger 12from accelerate.logging import get_logger
13from accelerate.utils import LoggerType, set_seed 13from accelerate.utils import LoggerType, set_seed
14from slugify import slugify 14from slugify import slugify
15import transformers
15 16
16from util.files import load_config, load_embeddings_from_dir 17from util.files import load_config, load_embeddings_from_dir
17from data.csv import VlpnDataModule, keyword_filter 18from data.csv import VlpnDataModule, keyword_filter
@@ -75,6 +76,12 @@ def parse_args():
75 help="A token to use as initializer word." 76 help="A token to use as initializer word."
76 ) 77 )
77 parser.add_argument( 78 parser.add_argument(
79 "--initializer_noise",
80 type=float,
81 default=0,
82 help="Noise to apply to the initializer word"
83 )
84 parser.add_argument(
78 "--alias_tokens", 85 "--alias_tokens",
79 type=str, 86 type=str,
80 nargs='*', 87 nargs='*',
@@ -323,7 +330,7 @@ def parse_args():
323 "--optimizer", 330 "--optimizer",
324 type=str, 331 type=str,
325 default="dadan", 332 default="dadan",
326 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' 333 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
327 ) 334 )
328 parser.add_argument( 335 parser.add_argument(
329 "--dadaptation_d0", 336 "--dadaptation_d0",
@@ -659,6 +666,19 @@ def main():
659 eps=args.adam_epsilon, 666 eps=args.adam_epsilon,
660 amsgrad=args.adam_amsgrad, 667 amsgrad=args.adam_amsgrad,
661 ) 668 )
669 elif args.optimizer == 'adafactor':
670 create_optimizer = partial(
671 transformers.optimization.Adafactor,
672 beta1=args.adam_beta1,
673 weight_decay=args.adam_weight_decay,
674 scale_parameter=True,
675 relative_step=True,
676 warmup_init=True,
677 )
678
679 args.lr_scheduler = "adafactor"
680 args.lr_min_lr = args.learning_rate
681 args.learning_rate = None
662 elif args.optimizer == 'dadam': 682 elif args.optimizer == 'dadam':
663 try: 683 try:
664 import dadaptation 684 import dadaptation
@@ -739,7 +759,8 @@ def main():
739 embeddings=embeddings, 759 embeddings=embeddings,
740 placeholder_tokens=placeholder_tokens, 760 placeholder_tokens=placeholder_tokens,
741 initializer_tokens=initializer_tokens, 761 initializer_tokens=initializer_tokens,
742 num_vectors=num_vectors 762 num_vectors=num_vectors,
763 initializer_noise=args.initializer_noise,
743 ) 764 )
744 765
745 stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) 766 stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
diff --git a/training/functional.py b/training/functional.py
index a2aa24e..ac43847 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,6 +231,7 @@ def add_placeholder_tokens(
231 placeholder_tokens: list[str], 231 placeholder_tokens: list[str],
232 initializer_tokens: list[str], 232 initializer_tokens: list[str],
233 num_vectors: Optional[Union[list[int], int]] = None, 233 num_vectors: Optional[Union[list[int], int]] = None,
234 initializer_noise: float = 0.0,
234): 235):
235 initializer_token_ids = [ 236 initializer_token_ids = [
236 tokenizer.encode(token, add_special_tokens=False) 237 tokenizer.encode(token, add_special_tokens=False)
@@ -245,7 +246,7 @@ def add_placeholder_tokens(
245 embeddings.resize(len(tokenizer)) 246 embeddings.resize(len(tokenizer))
246 247
247 for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): 248 for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
248 embeddings.add_embed(placeholder_token_id, initializer_token_id) 249 embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
249 250
250 return placeholder_token_ids, initializer_token_ids 251 return placeholder_token_ids, initializer_token_ids
251 252
diff --git a/training/optimization.py b/training/optimization.py
index 59ca950..53d0a6d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,6 +6,7 @@ import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
9import transformers
9 10
10 11
11class OneCyclePhase(NamedTuple): 12class OneCyclePhase(NamedTuple):
@@ -148,6 +149,8 @@ def get_scheduler(
148 num_training_steps=num_training_steps, 149 num_training_steps=num_training_steps,
149 num_cycles=cycles, 150 num_cycles=cycles,
150 ) 151 )
152 elif id == "adafactor":
153 lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
151 else: 154 else:
152 lr_scheduler = get_scheduler_( 155 lr_scheduler = get_scheduler_(
153 id, 156 id,