diff options
author | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
commit | 01eee0cb24f52ca78761b78917959e1c247eae94 (patch) | |
tree | 914c0d3f5b888a4c344b30a861639c8e3d5259dd | |
parent | Update (diff) | |
download | textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.gz textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.bz2 textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.zip |
Add support for Adafactor, add TI initializer noise
-rw-r--r-- | models/clip/embeddings.py | 10 | ||||
-rw-r--r-- | train_dreambooth.py | 16 | ||||
-rw-r--r-- | train_lora.py | 16 | ||||
-rw-r--r-- | train_ti.py | 25 | ||||
-rw-r--r-- | training/functional.py | 3 | ||||
-rw-r--r-- | training/optimization.py | 3 |
6 files changed, 67 insertions, 6 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 4166dc6..9abd1bb 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -52,7 +52,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | 52 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) |
53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 53 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
54 | 54 | ||
55 | def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): | 55 | def add_embed( |
56 | self, | ||
57 | token_ids: Union[int, list[int]], | ||
58 | initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None, | ||
59 | initializer_noise: float = 0.0, | ||
60 | ): | ||
56 | if isinstance(token_ids, int): | 61 | if isinstance(token_ids, int): |
57 | token_ids = [token_ids] | 62 | token_ids = [token_ids] |
58 | 63 | ||
@@ -73,6 +78,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
73 | dtype=self.temp_token_embedding.weight.dtype, | 78 | dtype=self.temp_token_embedding.weight.dtype, |
74 | ) | 79 | ) |
75 | 80 | ||
81 | if initializer_noise != 0: | ||
82 | initializer += torch.randn_like(initializer) * initializer_noise | ||
83 | |||
76 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 84 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
77 | 85 | ||
78 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 86 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 3a25efa..4456bd1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -13,6 +13,7 @@ from accelerate import Accelerator | |||
13 | from accelerate.logging import get_logger | 13 | from accelerate.logging import get_logger |
14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
15 | from slugify import slugify | 15 | from slugify import slugify |
16 | import transformers | ||
16 | 17 | ||
17 | from util.files import load_config, load_embeddings_from_dir | 18 | from util.files import load_config, load_embeddings_from_dir |
18 | from data.csv import VlpnDataModule, keyword_filter | 19 | from data.csv import VlpnDataModule, keyword_filter |
@@ -305,7 +306,7 @@ def parse_args(): | |||
305 | "--optimizer", | 306 | "--optimizer", |
306 | type=str, | 307 | type=str, |
307 | default="dadan", | 308 | default="dadan", |
308 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 309 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' |
309 | ) | 310 | ) |
310 | parser.add_argument( | 311 | parser.add_argument( |
311 | "--dadaptation_d0", | 312 | "--dadaptation_d0", |
@@ -535,6 +536,19 @@ def main(): | |||
535 | eps=args.adam_epsilon, | 536 | eps=args.adam_epsilon, |
536 | amsgrad=args.adam_amsgrad, | 537 | amsgrad=args.adam_amsgrad, |
537 | ) | 538 | ) |
539 | elif args.optimizer == 'adafactor': | ||
540 | create_optimizer = partial( | ||
541 | transformers.optimization.Adafactor, | ||
542 | beta1=args.adam_beta1, | ||
543 | weight_decay=args.adam_weight_decay, | ||
544 | scale_parameter=True, | ||
545 | relative_step=True, | ||
546 | warmup_init=True, | ||
547 | ) | ||
548 | |||
549 | args.lr_scheduler = "adafactor" | ||
550 | args.lr_min_lr = args.learning_rate | ||
551 | args.learning_rate = None | ||
538 | elif args.optimizer == 'dadam': | 552 | elif args.optimizer == 'dadam': |
539 | try: | 553 | try: |
540 | import dadaptation | 554 | import dadaptation |
diff --git a/train_lora.py b/train_lora.py index f74a438..f8dccae 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger | |||
14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
15 | from peft import LoraConfig, LoraModel | 15 | from peft import LoraConfig, LoraModel |
16 | from slugify import slugify | 16 | from slugify import slugify |
17 | import transformers | ||
17 | 18 | ||
18 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
19 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
@@ -317,7 +318,7 @@ def parse_args(): | |||
317 | "--optimizer", | 318 | "--optimizer", |
318 | type=str, | 319 | type=str, |
319 | default="dadan", | 320 | default="dadan", |
320 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 321 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' |
321 | ) | 322 | ) |
322 | parser.add_argument( | 323 | parser.add_argument( |
323 | "--dadaptation_d0", | 324 | "--dadaptation_d0", |
@@ -567,6 +568,19 @@ def main(): | |||
567 | eps=args.adam_epsilon, | 568 | eps=args.adam_epsilon, |
568 | amsgrad=args.adam_amsgrad, | 569 | amsgrad=args.adam_amsgrad, |
569 | ) | 570 | ) |
571 | elif args.optimizer == 'adafactor': | ||
572 | create_optimizer = partial( | ||
573 | transformers.optimization.Adafactor, | ||
574 | beta1=args.adam_beta1, | ||
575 | weight_decay=args.adam_weight_decay, | ||
576 | scale_parameter=True, | ||
577 | relative_step=True, | ||
578 | warmup_init=True, | ||
579 | ) | ||
580 | |||
581 | args.lr_scheduler = "adafactor" | ||
582 | args.lr_min_lr = args.learning_rate | ||
583 | args.learning_rate = None | ||
570 | elif args.optimizer == 'dadam': | 584 | elif args.optimizer == 'dadam': |
571 | try: | 585 | try: |
572 | import dadaptation | 586 | import dadaptation |
diff --git a/train_ti.py b/train_ti.py index dd015f9..274a1ca 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -12,6 +12,7 @@ from accelerate import Accelerator | |||
12 | from accelerate.logging import get_logger | 12 | from accelerate.logging import get_logger |
13 | from accelerate.utils import LoggerType, set_seed | 13 | from accelerate.utils import LoggerType, set_seed |
14 | from slugify import slugify | 14 | from slugify import slugify |
15 | import transformers | ||
15 | 16 | ||
16 | from util.files import load_config, load_embeddings_from_dir | 17 | from util.files import load_config, load_embeddings_from_dir |
17 | from data.csv import VlpnDataModule, keyword_filter | 18 | from data.csv import VlpnDataModule, keyword_filter |
@@ -75,6 +76,12 @@ def parse_args(): | |||
75 | help="A token to use as initializer word." | 76 | help="A token to use as initializer word." |
76 | ) | 77 | ) |
77 | parser.add_argument( | 78 | parser.add_argument( |
79 | "--initializer_noise", | ||
80 | type=float, | ||
81 | default=0, | ||
82 | help="Noise to apply to the initializer word" | ||
83 | ) | ||
84 | parser.add_argument( | ||
78 | "--alias_tokens", | 85 | "--alias_tokens", |
79 | type=str, | 86 | type=str, |
80 | nargs='*', | 87 | nargs='*', |
@@ -323,7 +330,7 @@ def parse_args(): | |||
323 | "--optimizer", | 330 | "--optimizer", |
324 | type=str, | 331 | type=str, |
325 | default="dadan", | 332 | default="dadan", |
326 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 333 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' |
327 | ) | 334 | ) |
328 | parser.add_argument( | 335 | parser.add_argument( |
329 | "--dadaptation_d0", | 336 | "--dadaptation_d0", |
@@ -659,6 +666,19 @@ def main(): | |||
659 | eps=args.adam_epsilon, | 666 | eps=args.adam_epsilon, |
660 | amsgrad=args.adam_amsgrad, | 667 | amsgrad=args.adam_amsgrad, |
661 | ) | 668 | ) |
669 | elif args.optimizer == 'adafactor': | ||
670 | create_optimizer = partial( | ||
671 | transformers.optimization.Adafactor, | ||
672 | beta1=args.adam_beta1, | ||
673 | weight_decay=args.adam_weight_decay, | ||
674 | scale_parameter=True, | ||
675 | relative_step=True, | ||
676 | warmup_init=True, | ||
677 | ) | ||
678 | |||
679 | args.lr_scheduler = "adafactor" | ||
680 | args.lr_min_lr = args.learning_rate | ||
681 | args.learning_rate = None | ||
662 | elif args.optimizer == 'dadam': | 682 | elif args.optimizer == 'dadam': |
663 | try: | 683 | try: |
664 | import dadaptation | 684 | import dadaptation |
@@ -739,7 +759,8 @@ def main(): | |||
739 | embeddings=embeddings, | 759 | embeddings=embeddings, |
740 | placeholder_tokens=placeholder_tokens, | 760 | placeholder_tokens=placeholder_tokens, |
741 | initializer_tokens=initializer_tokens, | 761 | initializer_tokens=initializer_tokens, |
742 | num_vectors=num_vectors | 762 | num_vectors=num_vectors, |
763 | initializer_noise=args.initializer_noise, | ||
743 | ) | 764 | ) |
744 | 765 | ||
745 | stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) | 766 | stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) |
diff --git a/training/functional.py b/training/functional.py index a2aa24e..ac43847 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -231,6 +231,7 @@ def add_placeholder_tokens( | |||
231 | placeholder_tokens: list[str], | 231 | placeholder_tokens: list[str], |
232 | initializer_tokens: list[str], | 232 | initializer_tokens: list[str], |
233 | num_vectors: Optional[Union[list[int], int]] = None, | 233 | num_vectors: Optional[Union[list[int], int]] = None, |
234 | initializer_noise: float = 0.0, | ||
234 | ): | 235 | ): |
235 | initializer_token_ids = [ | 236 | initializer_token_ids = [ |
236 | tokenizer.encode(token, add_special_tokens=False) | 237 | tokenizer.encode(token, add_special_tokens=False) |
@@ -245,7 +246,7 @@ def add_placeholder_tokens( | |||
245 | embeddings.resize(len(tokenizer)) | 246 | embeddings.resize(len(tokenizer)) |
246 | 247 | ||
247 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): | 248 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): |
248 | embeddings.add_embed(placeholder_token_id, initializer_token_id) | 249 | embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise) |
249 | 250 | ||
250 | return placeholder_token_ids, initializer_token_ids | 251 | return placeholder_token_ids, initializer_token_ids |
251 | 252 | ||
diff --git a/training/optimization.py b/training/optimization.py index 59ca950..53d0a6d 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -6,6 +6,7 @@ import torch | |||
6 | from torch.optim.lr_scheduler import LambdaLR | 6 | from torch.optim.lr_scheduler import LambdaLR |
7 | 7 | ||
8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup |
9 | import transformers | ||
9 | 10 | ||
10 | 11 | ||
11 | class OneCyclePhase(NamedTuple): | 12 | class OneCyclePhase(NamedTuple): |
@@ -148,6 +149,8 @@ def get_scheduler( | |||
148 | num_training_steps=num_training_steps, | 149 | num_training_steps=num_training_steps, |
149 | num_cycles=cycles, | 150 | num_cycles=cycles, |
150 | ) | 151 | ) |
152 | elif id == "adafactor": | ||
153 | lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr) | ||
151 | else: | 154 | else: |
152 | lr_scheduler = get_scheduler_( | 155 | lr_scheduler = get_scheduler_( |
153 | id, | 156 | id, |