diff options
-rw-r--r-- | train_dreambooth.py | 16 | ||||
-rw-r--r-- | train_ti.py | 15 | ||||
-rw-r--r-- | training/lr.py | 33 |
3 files changed, 55 insertions, 9 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 05f6cb5..1e49474 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -565,8 +565,6 @@ def main(): | |||
565 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 565 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( |
566 | args.pretrained_model_name_or_path, subfolder='scheduler') | 566 | args.pretrained_model_name_or_path, subfolder='scheduler') |
567 | 567 | ||
568 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
569 | |||
570 | vae.enable_slicing() | 568 | vae.enable_slicing() |
571 | vae.set_use_memory_efficient_attention_xformers(True) | 569 | vae.set_use_memory_efficient_attention_xformers(True) |
572 | unet.set_use_memory_efficient_attention_xformers(True) | 570 | unet.set_use_memory_efficient_attention_xformers(True) |
@@ -893,7 +891,16 @@ def main(): | |||
893 | accelerator.init_trackers("dreambooth", config=config) | 891 | accelerator.init_trackers("dreambooth", config=config) |
894 | 892 | ||
895 | if args.find_lr: | 893 | if args.find_lr: |
896 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 894 | lr_finder = LRFinder( |
895 | accelerator, | ||
896 | text_encoder, | ||
897 | optimizer, | ||
898 | train_dataloader, | ||
899 | val_dataloader, | ||
900 | loop, | ||
901 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | ||
902 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | ||
903 | ) | ||
897 | lr_finder.run(min_lr=1e-4) | 904 | lr_finder.run(min_lr=1e-4) |
898 | 905 | ||
899 | plt.savefig(basepath.joinpath("lr.png")) | 906 | plt.savefig(basepath.joinpath("lr.png")) |
@@ -965,11 +972,11 @@ def main(): | |||
965 | local_progress_bar.reset() | 972 | local_progress_bar.reset() |
966 | 973 | ||
967 | unet.train() | 974 | unet.train() |
968 | |||
969 | if epoch < args.train_text_encoder_epochs: | 975 | if epoch < args.train_text_encoder_epochs: |
970 | text_encoder.train() | 976 | text_encoder.train() |
971 | elif epoch == args.train_text_encoder_epochs: | 977 | elif epoch == args.train_text_encoder_epochs: |
972 | text_encoder.requires_grad_(False) | 978 | text_encoder.requires_grad_(False) |
979 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
973 | 980 | ||
974 | for step, batch in enumerate(train_dataloader): | 981 | for step, batch in enumerate(train_dataloader): |
975 | with accelerator.accumulate(unet): | 982 | with accelerator.accumulate(unet): |
@@ -1023,6 +1030,7 @@ def main(): | |||
1023 | 1030 | ||
1024 | unet.eval() | 1031 | unet.eval() |
1025 | text_encoder.eval() | 1032 | text_encoder.eval() |
1033 | tokenizer.set_use_vector_shuffle(False) | ||
1026 | 1034 | ||
1027 | cur_loss_val = AverageMeter() | 1035 | cur_loss_val = AverageMeter() |
1028 | cur_acc_val = AverageMeter() | 1036 | cur_acc_val = AverageMeter() |
diff --git a/train_ti.py b/train_ti.py index 97dde1e..2b3f017 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -535,8 +535,6 @@ def main(): | |||
535 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 535 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( |
536 | args.pretrained_model_name_or_path, subfolder='scheduler') | 536 | args.pretrained_model_name_or_path, subfolder='scheduler') |
537 | 537 | ||
538 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
539 | |||
540 | vae.enable_slicing() | 538 | vae.enable_slicing() |
541 | vae.set_use_memory_efficient_attention_xformers(True) | 539 | vae.set_use_memory_efficient_attention_xformers(True) |
542 | unet.set_use_memory_efficient_attention_xformers(True) | 540 | unet.set_use_memory_efficient_attention_xformers(True) |
@@ -845,7 +843,16 @@ def main(): | |||
845 | accelerator.init_trackers("textual_inversion", config=config) | 843 | accelerator.init_trackers("textual_inversion", config=config) |
846 | 844 | ||
847 | if args.find_lr: | 845 | if args.find_lr: |
848 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 846 | lr_finder = LRFinder( |
847 | accelerator, | ||
848 | text_encoder, | ||
849 | optimizer, | ||
850 | train_dataloader, | ||
851 | val_dataloader, | ||
852 | loop, | ||
853 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | ||
854 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | ||
855 | ) | ||
849 | lr_finder.run(min_lr=1e-4) | 856 | lr_finder.run(min_lr=1e-4) |
850 | 857 | ||
851 | plt.savefig(basepath.joinpath("lr.png")) | 858 | plt.savefig(basepath.joinpath("lr.png")) |
@@ -915,6 +922,7 @@ def main(): | |||
915 | local_progress_bar.reset() | 922 | local_progress_bar.reset() |
916 | 923 | ||
917 | text_encoder.train() | 924 | text_encoder.train() |
925 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
918 | 926 | ||
919 | for step, batch in enumerate(train_dataloader): | 927 | for step, batch in enumerate(train_dataloader): |
920 | with accelerator.accumulate(text_encoder): | 928 | with accelerator.accumulate(text_encoder): |
@@ -955,6 +963,7 @@ def main(): | |||
955 | accelerator.wait_for_everyone() | 963 | accelerator.wait_for_everyone() |
956 | 964 | ||
957 | text_encoder.eval() | 965 | text_encoder.eval() |
966 | tokenizer.set_use_vector_shuffle(False) | ||
958 | 967 | ||
959 | cur_loss_val = AverageMeter() | 968 | cur_loss_val = AverageMeter() |
960 | cur_acc_val = AverageMeter() | 969 | cur_acc_val = AverageMeter() |
diff --git a/training/lr.py b/training/lr.py index 3abd2f2..fe166ed 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -1,5 +1,6 @@ | |||
1 | import math | 1 | import math |
2 | import copy | 2 | import copy |
3 | from typing import Callable | ||
3 | 4 | ||
4 | import matplotlib.pyplot as plt | 5 | import matplotlib.pyplot as plt |
5 | import numpy as np | 6 | import numpy as np |
@@ -10,19 +11,45 @@ from tqdm.auto import tqdm | |||
10 | from training.util import AverageMeter | 11 | from training.util import AverageMeter |
11 | 12 | ||
12 | 13 | ||
14 | def noop(): | ||
15 | pass | ||
16 | |||
17 | |||
13 | class LRFinder(): | 18 | class LRFinder(): |
14 | def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn): | 19 | def __init__( |
20 | self, | ||
21 | accelerator, | ||
22 | model, | ||
23 | optimizer, | ||
24 | train_dataloader, | ||
25 | val_dataloader, | ||
26 | loss_fn, | ||
27 | on_train: Callable[[], None] = noop, | ||
28 | on_eval: Callable[[], None] = noop | ||
29 | ): | ||
15 | self.accelerator = accelerator | 30 | self.accelerator = accelerator |
16 | self.model = model | 31 | self.model = model |
17 | self.optimizer = optimizer | 32 | self.optimizer = optimizer |
18 | self.train_dataloader = train_dataloader | 33 | self.train_dataloader = train_dataloader |
19 | self.val_dataloader = val_dataloader | 34 | self.val_dataloader = val_dataloader |
20 | self.loss_fn = loss_fn | 35 | self.loss_fn = loss_fn |
36 | self.on_train = on_train | ||
37 | self.on_eval = on_eval | ||
21 | 38 | ||
22 | # self.model_state = copy.deepcopy(model.state_dict()) | 39 | # self.model_state = copy.deepcopy(model.state_dict()) |
23 | # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) | 40 | # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) |
24 | 41 | ||
25 | def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): | 42 | def run( |
43 | self, | ||
44 | min_lr, | ||
45 | skip_start: int = 10, | ||
46 | skip_end: int = 5, | ||
47 | num_epochs: int = 100, | ||
48 | num_train_batches: int = 1, | ||
49 | num_val_batches: int = math.inf, | ||
50 | smooth_f: float = 0.05, | ||
51 | diverge_th: int = 5 | ||
52 | ): | ||
26 | best_loss = None | 53 | best_loss = None |
27 | best_acc = None | 54 | best_acc = None |
28 | 55 | ||
@@ -50,6 +77,7 @@ class LRFinder(): | |||
50 | avg_acc = AverageMeter() | 77 | avg_acc = AverageMeter() |
51 | 78 | ||
52 | self.model.train() | 79 | self.model.train() |
80 | self.on_train() | ||
53 | 81 | ||
54 | for step, batch in enumerate(self.train_dataloader): | 82 | for step, batch in enumerate(self.train_dataloader): |
55 | if step >= num_train_batches: | 83 | if step >= num_train_batches: |
@@ -67,6 +95,7 @@ class LRFinder(): | |||
67 | progress_bar.update(1) | 95 | progress_bar.update(1) |
68 | 96 | ||
69 | self.model.eval() | 97 | self.model.eval() |
98 | self.on_eval() | ||
70 | 99 | ||
71 | with torch.inference_mode(): | 100 | with torch.inference_mode(): |
72 | for step, batch in enumerate(self.val_dataloader): | 101 | for step, batch in enumerate(self.val_dataloader): |