-rw-r--r--   train_dreambooth.py | 16 ++++++++++++----
-rw-r--r--   train_ti.py         | 15 ++++++++++++---
-rw-r--r--   training/lr.py      | 33 +++++++++++++++++++++++++++++++--
3 files changed, 55 insertions(+), 9 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 05f6cb5..1e49474 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -565,8 +565,6 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')

-    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
-
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -893,7 +891,16 @@ def main():
         accelerator.init_trackers("dreambooth", config=config)

     if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder = LRFinder(
+            accelerator,
+            text_encoder,
+            optimizer,
+            train_dataloader,
+            val_dataloader,
+            loop,
+            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
+            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+        )
         lr_finder.run(min_lr=1e-4)

         plt.savefig(basepath.joinpath("lr.png"))
@@ -965,11 +972,11 @@ def main():
         local_progress_bar.reset()

         unet.train()
-
         if epoch < args.train_text_encoder_epochs:
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
             text_encoder.requires_grad_(False)
+        tokenizer.set_use_vector_shuffle(args.vector_shuffle)

         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
@@ -1023,6 +1030,7 @@ def main():

         unet.eval()
         text_encoder.eval()
+        tokenizer.set_use_vector_shuffle(False)

         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
diff --git a/train_ti.py b/train_ti.py
index 97dde1e..2b3f017 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -535,8 +535,6 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')

-    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
-
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -845,7 +843,16 @@ def main():
         accelerator.init_trackers("textual_inversion", config=config)

     if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder = LRFinder(
+            accelerator,
+            text_encoder,
+            optimizer,
+            train_dataloader,
+            val_dataloader,
+            loop,
+            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
+            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+        )
         lr_finder.run(min_lr=1e-4)

         plt.savefig(basepath.joinpath("lr.png"))
@@ -915,6 +922,7 @@ def main():
         local_progress_bar.reset()

         text_encoder.train()
+        tokenizer.set_use_vector_shuffle(args.vector_shuffle)

         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
@@ -955,6 +963,7 @@ def main():
         accelerator.wait_for_everyone()

         text_encoder.eval()
+        tokenizer.set_use_vector_shuffle(False)

         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
diff --git a/training/lr.py b/training/lr.py
index 3abd2f2..fe166ed 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,6 @@
 import math
 import copy
+from typing import Callable

 import matplotlib.pyplot as plt
 import numpy as np
@@ -10,19 +11,45 @@ from tqdm.auto import tqdm
 from training.util import AverageMeter


+def noop():
+    pass
+
+
 class LRFinder():
-    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
+    def __init__(
+        self,
+        accelerator,
+        model,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        loss_fn,
+        on_train: Callable[[], None] = noop,
+        on_eval: Callable[[], None] = noop
+    ):
         self.accelerator = accelerator
         self.model = model
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
+        self.on_train = on_train
+        self.on_eval = on_eval

         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())

-    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(
+        self,
+        min_lr,
+        skip_start: int = 10,
+        skip_end: int = 5,
+        num_epochs: int = 100,
+        num_train_batches: int = 1,
+        num_val_batches: int = math.inf,
+        smooth_f: float = 0.05,
+        diverge_th: int = 5
+    ):
         best_loss = None
         best_acc = None

@@ -50,6 +77,7 @@ class LRFinder():
             avg_acc = AverageMeter()

             self.model.train()
+            self.on_train()

             for step, batch in enumerate(self.train_dataloader):
                 if step >= num_train_batches:
@@ -67,6 +95,7 @@ class LRFinder():
                 progress_bar.update(1)

             self.model.eval()
+            self.on_eval()

             with torch.inference_mode():
                 for step, batch in enumerate(self.val_dataloader):
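
A minimal usage sketch of the new hooks, mirroring the calls in the training scripts above. The surrounding objects (accelerator, text_encoder, optimizer, the dataloaders, loop, tokenizer and args) are assumed to be set up as in train_ti.py; this is not part of the diff itself.

    # Sketch: construct LRFinder with the new callbacks so token-vector
    # shuffling is active only while training batches are processed.
    from training.lr import LRFinder

    lr_finder = LRFinder(
        accelerator,
        text_encoder,
        optimizer,
        train_dataloader,
        val_dataloader,
        loop,
        # enable shuffling for the training pass of each epoch...
        on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
        # ...and disable it again before validation
        on_eval=lambda: tokenizer.set_use_vector_shuffle(False),
    )
    # inside run(), on_train() is invoked alongside model.train() and
    # on_eval() alongside model.eval() in every epoch
    lr_finder.run(min_lr=1e-4)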
