 models/clip/embeddings.py | 10
 train_dreambooth.py       | 16
 train_lora.py             | 16
 train_ti.py               | 25
 training/functional.py    |  3
 training/optimization.py  |  3
 6 files changed, 67 insertions(+), 6 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 4166dc6..9abd1bb 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -52,7 +52,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
-    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+    def add_embed(
+        self,
+        token_ids: Union[int, list[int]],
+        initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None,
+        initializer_noise: float = 0.0,
+    ):
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
@@ -73,6 +78,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             dtype=self.temp_token_embedding.weight.dtype,
         )
 
+        if initializer_noise != 0:
+            initializer += torch.randn_like(initializer) * initializer_noise
+
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
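A minimal sketch of what the new initializer_noise argument does, assuming initializer has already been resolved to a float tensor of embedding rows as in the hunk above; the tensor shape and the 0.05 noise scale below are made-up placeholders.

import torch

# Made-up stand-ins: one 768-dimensional embedding row, noise scale 0.05.
initializer = torch.randn(1, 768)
initializer_noise = 0.05

# Same perturbation as in the diff: add Gaussian noise scaled by initializer_noise.
if initializer_noise != 0:
    initializer = initializer + torch.randn_like(initializer) * initializer_noise

Presumably the jitter keeps placeholder vectors that share an initializer word from starting out identical; the default of 0.0 leaves the previous behaviour unchanged.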
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 3a25efa..4456bd1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -13,6 +13,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -305,7 +306,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -535,6 +536,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
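For context, a standalone sketch of the Adafactor setup this hunk adds (the same branch appears in train_lora.py and train_ti.py below); the tiny model and the beta1/weight_decay values are placeholders for the script's args.* fields.

import torch
import transformers

model = torch.nn.Linear(16, 16)  # placeholder for the parameters being fine-tuned

# With relative_step=True and warmup_init=True, Adafactor derives its own
# per-step learning rate, so no external lr is passed. This is also why the
# script clears args.learning_rate after stashing it in args.lr_min_lr, which
# the "adafactor" scheduler branch below appears to consume as min_lr.
optimizer = transformers.optimization.Adafactor(
    model.parameters(),
    beta1=0.9,          # placeholder for args.adam_beta1
    weight_decay=1e-2,  # placeholder for args.adam_weight_decay
    scale_parameter=True,
    relative_step=True,
    warmup_init=True,
)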
diff --git a/train_lora.py b/train_lora.py
index f74a438..f8dccae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -317,7 +318,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -567,6 +568,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
diff --git a/train_ti.py b/train_ti.py
index dd015f9..274a1ca 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,6 +12,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -75,6 +76,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--initializer_noise",
+        type=float,
+        default=0,
+        help="Noise to apply to the initializer word"
+    )
+    parser.add_argument(
         "--alias_tokens",
         type=str,
         nargs='*',
@@ -323,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -659,6 +666,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -739,7 +759,8 @@ def main():
         embeddings=embeddings,
         placeholder_tokens=placeholder_tokens,
         initializer_tokens=initializer_tokens,
-        num_vectors=num_vectors
+        num_vectors=num_vectors,
+        initializer_noise=args.initializer_noise,
     )
 
     stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
diff --git a/training/functional.py b/training/functional.py
index a2aa24e..ac43847 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,6 +231,7 @@ def add_placeholder_tokens(
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
     num_vectors: Optional[Union[list[int], int]] = None,
+    initializer_noise: float = 0.0,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
@@ -245,7 +246,7 @@ def add_placeholder_tokens(
     embeddings.resize(len(tokenizer))
 
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id)
+        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
 
     return placeholder_token_ids, initializer_token_ids
 
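A usage sketch of the updated helper, based only on the parameters visible in this hunk and on the call site in train_ti.py above; the tokenizer and embeddings objects are assumed to be the ones the training scripts already construct, and the token strings, vector count, and noise value are made up.

from training.functional import add_placeholder_tokens

# Assumes `tokenizer` (the scripts' CLIP tokenizer) and `embeddings` (the
# managed text-embedding module) already exist in scope.
placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
    tokenizer=tokenizer,
    embeddings=embeddings,
    placeholder_tokens=["<concept>"],  # hypothetical placeholder token
    initializer_tokens=["person"],     # hypothetical initializer word
    num_vectors=1,
    initializer_noise=0.05,            # hypothetical; 0.0 keeps the old behaviour
)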
diff --git a/training/optimization.py b/training/optimization.py
index 59ca950..53d0a6d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,6 +6,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+import transformers
 
 
 class OneCyclePhase(NamedTuple):
@@ -148,6 +149,8 @@ def get_scheduler(
             num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
+    elif id == "adafactor":
+        lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
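Finally, a compact sketch of the new scheduler branch: transformers' AdafactorSchedule paired with an Adafactor optimizer. The tiny model is a placeholder, and the 1e-4 passed as the schedule's initial value stands in for min_lr (presumably the old args.learning_rate that the scripts now save as args.lr_min_lr).

import torch
import transformers

model = torch.nn.Linear(16, 16)  # placeholder module
optimizer = transformers.optimization.Adafactor(
    model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True
)

# AdafactorSchedule is a proxy: it does not set a learning rate itself, it
# reports the one Adafactor derives internally. The second argument is only
# the value reported before the first optimizer step.
lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, 1e-4)

model(torch.randn(4, 16)).sum().backward()
optimizer.step()
print(lr_scheduler.get_lr())  # learning rates computed by Adafactor per param group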
