author     Volpeon <git@volpeon.ink>  2023-01-13 22:25:30 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-13 22:25:30 +0100
commit     3e7fbb7dce321435bbbb81361debfbc499bf9231 (patch)
tree       e7d5cefd2eda9755ab58861862f1978c13386f0d
parent     More modularization (diff)
Reverted modularization mostly
-rw-r--r--  train_dreambooth.py               3
-rw-r--r--  train_ti.py                     467
-rw-r--r--  training/common.py              264
-rw-r--r--  training/modules/dreambooth.py    0
-rw-r--r--  training/modules/lora.py          0
-rw-r--r--  training/modules/ti.py          284
-rw-r--r--  training/optimization.py         53
7 files changed, 458 insertions(+), 613 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c892ebf..2145e2b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -21,7 +21,8 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, train_loop, generate_class_images, get_scheduler
+from training.common import loss_step, train_loop, generate_class_images
+from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import CheckpointerBase, save_args
 from models.clip.embeddings import patch_managed_embeddings
diff --git a/train_ti.py b/train_ti.py
index 3a55f40..61195f6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,15 +1,29 @@
 import argparse
+import datetime
+import logging
+from functools import partial
+from pathlib import Path
+from contextlib import contextmanager, nullcontext

 import torch
 import torch.utils.checkpoint

+from accelerate import Accelerator
 from accelerate.logging import get_logger
-
-from util import load_config
-from data.csv import VlpnDataItem
-from training.common import train_setup
-from training.modules.ti import train_ti
-from training.util import save_args
+from accelerate.utils import LoggerType, set_seed
+from diffusers import AutoencoderKL, UNet2DConditionModel
+import matplotlib.pyplot as plt
+from transformers import CLIPTextModel
+from slugify import slugify
+
+from util import load_config, load_embeddings_from_dir
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from data.csv import VlpnDataModule, VlpnDataItem
+from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.optimization import get_scheduler
+from training.lr import LRFinder
+from training.util import CheckpointerBase, EMAModel, save_args
+from models.clip.tokenizer import MultiCLIPTokenizer

 logger = get_logger(__name__)

@@ -52,13 +66,13 @@ def parse_args():
52 help="The name of the current project.", 66 help="The name of the current project.",
53 ) 67 )
54 parser.add_argument( 68 parser.add_argument(
55 "--placeholder_token", 69 "--placeholder_tokens",
56 type=str, 70 type=str,
57 nargs='*', 71 nargs='*',
58 help="A token to use as a placeholder for the concept.", 72 help="A token to use as a placeholder for the concept.",
59 ) 73 )
60 parser.add_argument( 74 parser.add_argument(
61 "--initializer_token", 75 "--initializer_tokens",
62 type=str, 76 type=str,
63 nargs='*', 77 nargs='*',
64 help="A token to use as initializer word." 78 help="A token to use as initializer word."
@@ -439,29 +453,29 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

-    if isinstance(args.placeholder_token, str):
-        args.placeholder_token = [args.placeholder_token]
+    if isinstance(args.placeholder_tokens, str):
+        args.placeholder_tokens = [args.placeholder_tokens]

-    if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)]

-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)

-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")

-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

     if args.num_vectors is None:
         args.num_vectors = 1

     if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)

-    if len(args.placeholder_token) != len(args.num_vectors):
-        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.num_vectors):
+        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -475,13 +489,197 @@ def parse_args():
     return args


+class Checkpointer(CheckpointerBase):
+    def __init__(
+        self,
+        weight_dtype,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
+        scheduler,
+        placeholder_tokens,
+        placeholder_token_ids,
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.weight_dtype = weight_dtype
+        self.accelerator = accelerator
+        self.vae = vae
+        self.unet = unet
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
+        self.scheduler = scheduler
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
+
+    @torch.no_grad()
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+        del text_encoder
+
+    @torch.no_grad()
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            orig_dtype = text_encoder.dtype
+            text_encoder.to(dtype=self.weight_dtype)
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=self.unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+
+            text_encoder.to(dtype=orig_dtype)
+
+        del text_encoder
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
 def main():
     args = parse_args()

-    def data_filter(item: VlpnDataItem):
+    global_step_offset = args.global_step
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
+    basepath.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
+
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
+
+    save_args(basepath, args)
+
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
+        args.pretrained_model_name_or_path)
+
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    if args.embeddings_dir is not None:
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
+    placeholder_token_ids = add_placeholder_tokens(
+        tokenizer=tokenizer,
+        embeddings=embeddings,
+        placeholder_tokens=args.placeholder_tokens,
+        initializer_tokens=args.initializer_tokens,
+        num_vectors=args.num_vectors
+    )
+
+    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+
+    if args.use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+        )
+    else:
+        ema_embeddings = None
+
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+
+    if args.find_lr:
+        args.learning_rate = 1e-5
+
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    optimizer = optimizer_class(
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    def keyword_filter(item: VlpnDataItem):
         cond1 = any(
             keyword in part
-            for keyword in args.placeholder_token
+            for keyword in args.placeholder_tokens
             for part in item.prompt
         )
         cond3 = args.collection is None or args.collection in item.collection
@@ -491,78 +689,185 @@ def main():
         )
         return cond1 and cond3 and cond4

-    setup = train_setup(
-        output_dir=args.output_dir,
-        project=args.project,
-        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
-        learning_rate=args.learning_rate,
+    datamodule = VlpnDataModule(
         data_file=args.train_data_file,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision,
-        seed=args.seed,
-        vector_shuffle=args.vector_shuffle,
-        vector_dropout=args.vector_dropout,
-        gradient_checkpointing=args.gradient_checkpointing,
-        embeddings_dir=args.embeddings_dir,
-        placeholder_token=args.placeholder_token,
-        initializer_token=args.initializer_token,
-        num_vectors=args.num_vectors,
-        scale_lr=args.scale_lr,
-        use_8bit_adam=args.use_8bit_adam,
-        train_batch_size=args.train_batch_size,
-        class_image_dir=args.class_image_dir,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
-        resolution=args.resolution,
+        size=args.resolution,
         num_buckets=args.num_buckets,
         progressive_buckets=args.progressive_buckets,
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
-        tag_dropout=args.tag_dropout,
-        tag_shuffle=not args.no_tag_shuffle,
-        data_template=args.train_data_template,
+        dropout=args.tag_dropout,
+        shuffle=not args.no_tag_shuffle,
+        template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        data_filter=data_filter,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_steps=args.sample_steps,
+        num_workers=args.dataloader_num_workers,
+        seed=args.seed,
+        filter=keyword_filter,
+        dtype=weight_dtype
+    )
+    datamodule.setup()
+
+    train_dataloader = datamodule.train_dataloader
+    val_dataloader = datamodule.val_dataloader
+
+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            sample_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )
+
+    if args.find_lr:
+        lr_scheduler = None
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_training_steps_per_epoch=len(train_dataloader),
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
+        )
+
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    save_args(setup.output_dir, args)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device, dtype=weight_dtype)

-    train_ti(
-        setup=setup,
-        num_train_epochs=args.num_train_epochs,
-        num_class_images=args.num_class_images,
-        prior_loss_weight=args.prior_loss_weight,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
-        adam_beta1=args.adam_beta1,
-        adam_beta2=args.adam_beta2,
-        adam_weight_decay=args.adam_weight_decay,
-        adam_epsilon=args.adam_epsilon,
-        adam_amsgrad=args.adam_amsgrad,
-        lr_scheduler=args.lr_scheduler,
-        lr_min_lr=args.lr_min_lr,
-        lr_warmup_func=args.lr_warmup_func,
-        lr_annealing_func=args.lr_annealing_func,
-        lr_warmup_exp=args.lr_warmup_exp,
-        lr_annealing_exp=args.lr_annealing_exp,
-        lr_cycles=args.lr_cycles,
-        lr_warmup_epochs=args.lr_warmup_epochs,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay_factor=args.emb_decay_factor,
-        emb_decay_start=args.emb_decay_start,
+    if args.use_ema:
+        ema_embeddings.to(accelerator.device)
+
+    if args.gradient_checkpointing:
+        unet.train()
+    else:
+        unet.eval()
+
+    @contextmanager
+    def on_train(epoch: int):
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval():
+        try:
+            tokenizer.eval()
+
+            ema_context = ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        text_encoder.text_model.embeddings.normalize(
+            args.emb_decay_target,
+            min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
+        )
+
+        if args.use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if args.use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        args.num_class_images != 0,
+        args.prior_loss_weight,
+        args.seed,
+    )
+
+    checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
+        scheduler=sample_scheduler,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
-        sample_frequency=args.sample_frequency,
-        sample_steps=args.sample_steps,
-        checkpoint_frequency=args.checkpoint_frequency,
-        global_step_offset=args.global_step,
-    )
+        seed=args.seed
+    )
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
+    if args.find_lr:
+        lr_finder = LRFinder(
+            accelerator=accelerator,
+            optimizer=optimizer,
+            model=text_encoder,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
+            on_train=on_train,
+            on_eval=on_eval,
+            on_after_optimize=on_after_optimize,
+        )
+        lr_finder.run(num_epochs=100, end_lr=1e3)
+
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.close()
+    else:
+        train_loop(
+            accelerator=accelerator,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            model=text_encoder,
+            checkpointer=checkpointer,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
+            sample_frequency=args.sample_frequency,
+            sample_steps=args.sample_steps,
+            checkpoint_frequency=args.checkpoint_frequency,
+            global_step_offset=global_step_offset,
+            num_epochs=args.num_train_epochs,
+            on_log=on_log,
+            on_train=on_train,
+            on_after_optimize=on_after_optimize,
+            on_eval=on_eval
+        )


 if __name__ == "__main__":
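
A minimal sketch of the multi-token flag normalization that parse_args() performs above; the token values here are hypothetical examples, and the broadcasting mirrors the isinstance checks in the hunk at lines 453-481 of train_ti.py:

    # Hypothetical values for --placeholder_tokens / --initializer_tokens / --num_vectors
    placeholder_tokens = ["<tok1>", "<tok2>"]
    initializer_tokens = ["face", "face"]
    num_vectors = 4

    # A single int is broadcast to one entry per token, as in the diff above.
    if isinstance(num_vectors, int):
        num_vectors = [num_vectors] * len(initializer_tokens)

    assert len(placeholder_tokens) == len(initializer_tokens) == len(num_vectors)
    print(list(zip(placeholder_tokens, initializer_tokens, num_vectors)))
    # [('<tok1>', 'face', 4), ('<tok2>', 'face', 4)]
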
diff --git a/training/common.py b/training/common.py
index 73ce814..b6964a3 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,52 +1,24 @@
 import math
-from pathlib import Path
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple
-import datetime
-import logging
+from typing import Callable, Any, Tuple, Union

 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader

 from accelerate import Accelerator
-from accelerate.utils import LoggerType, set_seed
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
-from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup

 from tqdm.auto import tqdm
-from slugify import slugify

-from data.csv import VlpnDataModule, VlpnDataItem
-from util import load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.embeddings import patch_managed_embeddings
+from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.optimization import get_one_cycle_schedule
 from training.util import AverageMeter, CheckpointerBase


-class TrainingSetup(NamedTuple):
-    accelerator: Accelerator
-    tokenizer: MultiCLIPTokenizer
-    text_encoder: CLIPTextModel
-    vae: AutoencoderKL
-    unet: UNet2DConditionModel
-    noise_scheduler: DDPMScheduler
-    checkpoint_scheduler: DPMSolverMultistepScheduler
-    optimizer_class: Callable
-    learning_rate: float
-    weight_dtype: torch.dtype
-    output_dir: Path
-    seed: int
-    train_dataloader: DataLoader
-    val_dataloader: DataLoader
-    placeholder_token: list[str]
-    placeholder_token_ids: list[list[int]]
-
-
 def noop(*args, **kwards):
     pass

@@ -59,57 +31,6 @@ def noop_on_log():
     return {}


-def get_scheduler(
-    id: str,
-    optimizer: torch.optim.Optimizer,
-    num_training_steps_per_epoch: int,
-    gradient_accumulation_steps: int,
-    min_lr: float = 0.04,
-    warmup_func: str = "cos",
-    annealing_func: str = "cos",
-    warmup_exp: int = 1,
-    annealing_exp: int = 1,
-    cycles: int = 1,
-    train_epochs: int = 100,
-    warmup_epochs: int = 10,
-):
-    num_training_steps_per_epoch = math.ceil(
-        num_training_steps_per_epoch / gradient_accumulation_steps
-    ) * gradient_accumulation_steps
-    num_training_steps = train_epochs * num_training_steps_per_epoch
-    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
-
-    if id == "one_cycle":
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=num_training_steps,
-            warmup=warmup_func,
-            annealing=annealing_func,
-            warmup_exp=warmup_exp,
-            annealing_exp=annealing_exp,
-            min_lr=min_lr,
-        )
-    elif id == "cosine_with_restarts":
-        if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
-
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=num_warmup_steps,
-            num_training_steps=num_training_steps,
-            num_cycles=cycles,
-        )
-    else:
-        lr_scheduler = get_scheduler_(
-            id,
-            optimizer=optimizer,
-            num_warmup_steps=num_warmup_steps,
-            num_training_steps=num_training_steps,
-        )
-
-    return lr_scheduler
-
-
 def generate_class_images(
     accelerator,
     text_encoder,
@@ -162,194 +83,43 @@ def generate_class_images(
         torch.cuda.empty_cache()


-def train_setup(
-    output_dir: str,
-    project: str,
-    pretrained_model_name_or_path: str,
-    learning_rate: float,
-    data_file: str,
-    gradient_accumulation_steps: int = 1,
-    mixed_precision: Literal["no", "fp16", "bf16"] = "no",
-    seed: Optional[int] = None,
-    vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto",
-    vector_dropout: float = 0.1,
-    gradient_checkpointing: bool = True,
-    embeddings_dir: Optional[str] = None,
-    placeholder_token: list[str] = [],
-    initializer_token: list[str] = [],
-    num_vectors: int = 1,
-    scale_lr: bool = False,
-    use_8bit_adam: bool = False,
-    train_batch_size: int = 1,
-    class_image_dir: Optional[str] = None,
-    num_class_images: int = 0,
-    resolution: int = 768,
-    num_buckets: int = 0,
-    progressive_buckets: bool = False,
-    bucket_step_size: int = 64,
-    bucket_max_pixels: Optional[int] = None,
-    tag_dropout: float = 0.1,
-    tag_shuffle: bool = True,
-    data_template: str = "template",
-    valid_set_size: Optional[int] = None,
-    valid_set_repeat: int = 1,
-    data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
-    sample_batch_size: int = 1,
-    sample_image_size: int = 768,
-    sample_steps: int = 20,
-) -> TrainingSetup:
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(output_dir).joinpath(slugify(project), now)
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    accelerator = Accelerator(
-        log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{output_dir}",
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        mixed_precision=mixed_precision
-    )
-
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
-
-    seed = seed or (torch.random.seed() >> 32)
-    set_seed(seed)
-
-    # Load the tokenizer and add the placeholder token as a additional special token
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    tokenizer.set_use_vector_shuffle(vector_shuffle)
-    tokenizer.set_dropout(vector_dropout)
-
-    # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')

     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)

-    if gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
     embeddings = patch_managed_embeddings(text_encoder)

-    if embeddings_dir is not None:
-        embeddings_dir = Path(embeddings_dir)
-        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-            raise ValueError("--embeddings_dir must point to an existing directory")
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings

-        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

-    # Convert the initializer_token, placeholder_token to ids
+def add_placeholder_tokens(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    placeholder_tokens: list[str],
+    initializer_tokens: list[str],
+    num_vectors: Union[list[int], int]
+):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
-        for token in initializer_token
+        for token in initializer_tokens
     ]
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)

-    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
     embeddings.resize(len(tokenizer))

-    for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(new_id, init_ids)
-
-    init_ratios = [
-        f"{len(init_ids)} / {len(new_id)}"
-        for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
-    ]
-
-    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")
+    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(placeholder_token_id, initializer_token_id)

-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    if scale_lr:
-        learning_rate = (
-            learning_rate * gradient_accumulation_steps *
-            train_batch_size * accelerator.num_processes
-        )
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
-    if use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
-
-        optimizer_class = bnb.optim.AdamW8bit
-    else:
-        optimizer_class = torch.optim.AdamW
-
-    weight_dtype = torch.float32
-    if mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    datamodule = VlpnDataModule(
-        data_file=data_file,
-        batch_size=train_batch_size,
-        tokenizer=tokenizer,
-        class_subdir=class_image_dir,
-        num_class_images=num_class_images,
-        size=resolution,
-        num_buckets=num_buckets,
-        progressive_buckets=progressive_buckets,
-        bucket_step_size=bucket_step_size,
-        bucket_max_pixels=bucket_max_pixels,
-        dropout=tag_dropout,
-        shuffle=tag_shuffle,
-        template_key=data_template,
-        valid_set_size=valid_set_size,
-        valid_set_repeat=valid_set_repeat,
-        seed=seed,
-        filter=data_filter,
-        dtype=weight_dtype
-    )
-    datamodule.setup()
-
-    train_dataloader = datamodule.train_dataloader
-    val_dataloader = datamodule.val_dataloader
-
-    train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)
-
-    if num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            checkpoint_scheduler,
-            datamodule.data_train,
-            sample_batch_size,
-            sample_image_size,
-            sample_steps
-        )
-
-    return TrainingSetup(
-        accelerator=accelerator,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        noise_scheduler=noise_scheduler,
-        checkpoint_scheduler=checkpoint_scheduler,
-        optimizer_class=optimizer_class,
-        learning_rate=learning_rate,
-        output_dir=output_dir,
-        weight_dtype=weight_dtype,
-        seed=seed,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        placeholder_token=placeholder_token,
-        placeholder_token_ids=placeholder_token_ids
-    )
+    return placeholder_token_ids


 def loss_step(
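
The two helpers above replace the monolithic train_setup(). A minimal usage sketch under the signatures shown in this diff; the model path and token values are only examples:

    from training.common import get_models, add_placeholder_tokens

    # Load tokenizer, models, schedulers, and the patched managed embeddings.
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "path/to/pretrained-model")  # example path

    # Register new placeholder tokens and seed their embeddings from the
    # initializer tokens; returns one list of ids per placeholder token.
    placeholder_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=["<tok1>"],
        initializer_tokens=["face"],
        num_vectors=4,
    )
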
diff --git a/training/modules/dreambooth.py b/training/modules/dreambooth.py
deleted file mode 100644
index e69de29..0000000
--- a/training/modules/dreambooth.py
+++ /dev/null
diff --git a/training/modules/lora.py b/training/modules/lora.py
deleted file mode 100644
index e69de29..0000000
--- a/training/modules/lora.py
+++ /dev/null
diff --git a/training/modules/ti.py b/training/modules/ti.py
deleted file mode 100644
index 2db6f88..0000000
--- a/training/modules/ti.py
+++ /dev/null
@@ -1,284 +0,0 @@
-from typing import Literal
-from functools import partial
-from contextlib import contextmanager, nullcontext
-
-import torch
-
-from slugify import slugify
-
-from accelerate import Accelerator
-from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, UNet2DConditionModel
-
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.tokenizer import MultiCLIPTokenizer
-
-from training.common import TrainingSetup, get_scheduler, train_loop, loss_step
-from training.util import EMAModel, CheckpointerBase
-
-
-class Checkpointer(CheckpointerBase):
-    def __init__(
-        self,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
-        ema_embeddings: EMAModel,
-        weight_dtype: torch.dtype,
-        scheduler,
-        placeholder_token,
-        placeholder_token_ids,
-        *args,
-        **kwargs
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.ema_embeddings = ema_embeddings
-        self.scheduler = scheduler
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_ids = placeholder_token_ids
-
-    @torch.no_grad()
-    def checkpoint(self, step, postfix):
-        print("Saving checkpoint for step %d..." % step)
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = nullcontext()
-        if self.ema_embeddings is not None:
-            ema_context = self.ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-        with ema_context:
-            for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-                )
-
-        del text_encoder
-
-    @torch.no_grad()
-    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = nullcontext()
-        if self.ema_embeddings is not None:
-            ema_context = self.ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-        with ema_context:
-            orig_dtype = text_encoder.dtype
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=self.unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
-
-            text_encoder.to(dtype=orig_dtype)
-
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-
-def train_ti(
-    setup: TrainingSetup,
-    num_train_epochs: int = 100,
-    num_class_images: int = 0,
-    prior_loss_weight: float = 1.0,
-    use_ema: bool = False,
-    ema_inv_gamma: float = 1.0,
-    ema_power: float = 4/5,
-    ema_max_decay: float = .9999,
-    adam_beta1: float = 0.9,
-    adam_beta2: float = 0.999,
-    adam_weight_decay: float = 0,
-    adam_epsilon: float = 1e-08,
-    adam_amsgrad: bool = False,
-    lr_scheduler: Literal[
-        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle"
-    ] = "one_cycle",
-    lr_min_lr: float = 0.04,
-    lr_warmup_func: Literal["linear", "cos"] = "cos",
-    lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos",
-    lr_warmup_exp: int = 1,
-    lr_annealing_exp: int = 1,
-    lr_cycles: int = 1,
-    lr_warmup_epochs: int = 10,
-    emb_decay_target: float = 0.4,
-    emb_decay_factor: float = 1,
-    emb_decay_start: float = 1e-4,
-    sample_image_size: int = 768,
-    sample_batch_size: int = 1,
-    sample_batches: int = 1,
-    sample_frequency: int = 10,
-    sample_steps: int = 20,
-    checkpoint_frequency: int = 50,
-    global_step_offset: int = 0,
-):
-    if use_ema:
-        ema_embeddings = EMAModel(
-            setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            inv_gamma=ema_inv_gamma,
-            power=ema_power,
-            max_value=ema_max_decay,
-        )
-    else:
-        ema_embeddings = None
-
-    setup.text_encoder.requires_grad_(True)
-    setup.text_encoder.text_model.encoder.requires_grad_(False)
-    setup.text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-    setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
-    # Initialize the optimizer
-    optimizer = setup.optimizer_class(
-        setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=setup.learning_rate,
-        betas=(adam_beta1, adam_beta2),
-        weight_decay=adam_weight_decay,
-        eps=adam_epsilon,
-        amsgrad=adam_amsgrad,
-    )
-
-    lr_scheduler = get_scheduler(
-        lr_scheduler,
-        optimizer=optimizer,
-        min_lr=lr_min_lr,
-        warmup_func=lr_warmup_func,
-        annealing_func=lr_annealing_func,
-        warmup_exp=lr_warmup_exp,
-        annealing_exp=lr_annealing_exp,
-        cycles=lr_cycles,
-        train_epochs=num_train_epochs,
-        warmup_epochs=lr_warmup_epochs,
-        num_training_steps_per_epoch=len(setup.train_dataloader),
-        gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps
-    )
-
-    text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare(
-        setup.text_encoder, optimizer, lr_scheduler
-    )
-
-    # Move vae and unet to device
-    setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype)
-    setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype)
-
-    if use_ema:
-        ema_embeddings.to(setup.accelerator.device)
-
-    setup.unet.train()
-
-    @contextmanager
-    def on_train(epoch: int):
-        try:
-            setup.tokenizer.train()
-            yield
-        finally:
-            pass
-
-    @contextmanager
-    def on_eval():
-        try:
-            setup.tokenizer.eval()
-
-            ema_context = nullcontext()
-            if use_ema:
-                ema_context = ema_embeddings.apply_temporary(
-                    text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-            with ema_context:
-                yield
-        finally:
-            pass
-
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(
-            emb_decay_target,
-            min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start))))
-        )
-
-        if use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    def on_log():
-        if use_ema:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
-
-    loss_step_ = partial(
-        loss_step,
-        setup.vae,
-        setup.noise_scheduler,
-        setup.unet,
-        text_encoder,
-        num_class_images != 0,
-        prior_loss_weight,
-        setup.seed,
-    )
-
-    checkpointer = Checkpointer(
-        accelerator=setup.accelerator,
-        vae=setup.vae,
-        unet=setup.unet,
-        tokenizer=setup.tokenizer,
-        text_encoder=text_encoder,
-        ema_embeddings=ema_embeddings,
-        weight_dtype=setup.weight_dtype,
-        scheduler=setup.checkpoint_scheduler,
-        placeholder_token=setup.placeholder_token,
-        placeholder_token_ids=setup.placeholder_token_ids,
-        train_dataloader=setup.train_dataloader,
-        val_dataloader=setup.val_dataloader,
-        output_dir=setup.output_dir,
-        seed=setup.seed,
-        sample_image_size=sample_image_size,
-        sample_batch_size=sample_batch_size,
-        sample_batches=sample_batches
-    )
-
-    if setup.accelerator.is_main_process:
-        setup.accelerator.init_trackers("textual_inversion")
-
-    train_loop(
-        accelerator=setup.accelerator,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        model=text_encoder,
-        checkpointer=checkpointer,
-        train_dataloader=setup.train_dataloader,
-        val_dataloader=setup.val_dataloader,
-        loss_step=loss_step_,
-        sample_frequency=sample_frequency,
-        sample_steps=sample_steps,
-        checkpoint_frequency=checkpoint_frequency,
-        global_step_offset=global_step_offset,
-        num_epochs=num_train_epochs,
-        on_log=on_log,
-        on_train=on_train,
-        on_after_optimize=on_after_optimize,
-        on_eval=on_eval
-    )
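
Both the deleted module above and the restored train_ti.py lean on EMAModel.apply_temporary() to swap EMA-averaged embeddings in for checkpointing and sampling, then restore the live training weights afterwards. A self-contained sketch of that pattern (an illustration only, not the EMAModel implementation from training.util):

    import torch
    from contextlib import contextmanager

    @contextmanager
    def apply_temporary(params, avg_params):
        # Back up the live weights, copy the averaged weights in, and
        # restore the originals when the context exits.
        backup = [p.detach().clone() for p in params]
        with torch.no_grad():
            for p, avg in zip(params, avg_params):
                p.copy_(avg)
        try:
            yield
        finally:
            with torch.no_grad():
                for p, b in zip(params, backup):
                    p.copy_(b)

    p = torch.nn.Parameter(torch.ones(2))
    with apply_temporary([p], [torch.zeros(2)]):
        print(p)  # averaged weights active here
    print(p)      # live training weights restored
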
diff --git a/training/optimization.py b/training/optimization.py
index dd84f9c..5db7794 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,6 +5,8 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR

+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+

 class OneCyclePhase(NamedTuple):
     step_min: int
@@ -83,3 +85,54 @@ def get_one_cycle_schedule(
         return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)

     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_scheduler(
+    id: str,
+    optimizer: torch.optim.Optimizer,
+    num_training_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
+):
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
+
+    if id == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=num_training_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+        )
+
+    return lr_scheduler
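
With get_scheduler() now living in training/optimization.py, a minimal driving sketch under the signature shown above; the optimizer and step counts below are illustrative only:

    import torch
    from training.optimization import get_scheduler

    params = [torch.nn.Parameter(torch.zeros(4))]
    optimizer = torch.optim.AdamW(params, lr=1e-3)

    lr_scheduler = get_scheduler(
        "one_cycle",                       # or "cosine_with_restarts", or any diffusers scheduler id
        optimizer=optimizer,
        num_training_steps_per_epoch=100,  # e.g. len(train_dataloader)
        gradient_accumulation_steps=1,
        train_epochs=100,
        warmup_epochs=10,
    )

    optimizer.step()
    lr_scheduler.step()  # advance once per optimizer step
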