 train_dreambooth.py | 16
 train_lora.py       | 10
 train_ti.py         | 21
 training/lr.py      | 46
 training/util.py    |  5
 5 files changed, 59 insertions(+), 39 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 325fe90..202d52c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -580,12 +580,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -905,9 +903,7 @@ def main():
         if epoch < args.train_text_encoder_epochs:
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
-            freeze_params(text_encoder.parameters())
-
-        sample_checkpoint = False
+            text_encoder.requires_grad_(False)
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
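
Note: the freeze calls above rely on torch.nn.Module.requires_grad_, which flips requires_grad on every parameter owned by the module (recursively), so it covers the same parameters the removed freeze_params helper (dropped from training/util.py below) used to loop over. A minimal sketch of the equivalence, using a stand-in module rather than the real CLIP text encoder:

import torch
from torch import nn

# Stand-in for one of the frozen text-encoder submodules.
encoder = nn.TransformerEncoderLayer(d_model=8, nhead=2)

# Old style: the removed freeze_params() helper iterated over parameters.
for param in encoder.parameters():
    param.requires_grad = False

# New style: Module.requires_grad_(False) performs the same loop internally
# and applies to all submodules, then returns the module.
encoder.requires_grad_(False)

assert not any(p.requires_grad for p in encoder.parameters())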
diff --git a/train_lora.py b/train_lora.py
index ffca304..9a42cae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.lora import LoraAttnProcessor
 from training.optimization import get_one_cycle_schedule
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -513,11 +513,9 @@ def main():
 
     print(f"Training added text embeddings")
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     index_fixed_tokens = torch.arange(len(tokenizer))
     index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
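
Note: the unchanged index_fixed_tokens lines at the end of this hunk build the index of embeddings that stay fixed by masking out the placeholder ids with torch.isin. A small sketch with a hypothetical 10-token vocabulary and two hypothetical placeholder ids:

import torch

vocab_size = 10                 # hypothetical tokenizer length
placeholder_token_id = [8, 9]   # hypothetical ids of the added tokens

index_fixed_tokens = torch.arange(vocab_size)
index_fixed_tokens = index_fixed_tokens[
    ~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))
]
print(index_fixed_tokens)  # tensor([0, 1, 2, 3, 4, 5, 6, 7])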
diff --git a/train_ti.py b/train_ti.py
index 870b2ba..d7696e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -533,12 +533,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -548,6 +546,9 @@ def main():
            args.train_batch_size * accelerator.num_processes
        )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
@@ -715,7 +716,11 @@ def main():
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
-    unet.eval()
+
+    if args.gradient_checkpointing:
+        unet.train()
+    else:
+        unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
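
Note: the unet.train() branch looks odd for a frozen model; the point (assuming a diffusers-style UNet) is that gradient checkpointing is only applied while the module is in training mode, so eval() would silently disable it while the weights stay frozen either way. A rough, self-contained illustration with a toy block that follows the same "self.training and self.gradient_checkpointing" pattern:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    # Stand-in for a UNet block that only checkpoints while self.training is True.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 4), nn.GELU(), nn.Linear(4, 4))
        self.gradient_checkpointing = True

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            # Recompute activations during backward instead of storing them.
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)


block = TinyBlock().requires_grad_(False)   # frozen, like the UNet in train_ti.py
block.train()                               # train() only re-enables checkpointing here
out = block(torch.randn(2, 4, requires_grad=True))
out.sum().backward()                        # gradients still flow back to the input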
diff --git a/training/lr.py b/training/lr.py
index 8e558e1..c1fa3a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -22,10 +22,13 @@ class LRFinder():
         self.model_state = copy.deepcopy(model.state_dict())
         self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
+        best_acc = None
+
         lrs = []
         losses = []
+        accs = []
 
         lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
 
@@ -44,6 +47,7 @@ class LRFinder():
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
             avg_loss = AverageMeter()
+            avg_acc = AverageMeter()
 
             self.model.train()
 
@@ -71,28 +75,37 @@ class LRFinder():
 
                 loss, acc, bsz = self.loss_fn(batch)
                 avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
 
                 progress_bar.update(1)
 
             lr_scheduler.step()
 
             loss = avg_loss.avg.item()
+            acc = avg_acc.avg.item()
+
             if epoch == 0:
                 best_loss = loss
+                best_acc = acc
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                 if loss < best_loss:
                     best_loss = loss
+                if acc > best_acc:
+                    best_acc = acc
 
             lr = lr_scheduler.get_last_lr()[0]
 
             lrs.append(lr)
             losses.append(loss)
+            accs.append(acc)
 
             progress_bar.set_postfix({
                 "loss": loss,
-                "best": best_loss,
+                "loss/best": best_loss,
+                "acc": acc,
+                "acc/best": best_acc,
                 "lr": lr,
             })
 
@@ -103,20 +116,37 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        fig, ax = plt.subplots()
-        ax.plot(lrs, losses)
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
+        fig, ax_loss = plt.subplots()
+
+        ax_loss.plot(lrs, losses, color='red', label='Loss')
+        ax_loss.set_xscale("log")
+        ax_loss.set_xlabel("Learning rate")
+
+        # ax_acc = ax_loss.twinx()
+        # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
+
         if min_grad_idx is not None:
             print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
-            ax.scatter(
+            ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
                 s=75,
@@ -125,11 +155,7 @@ class LRFinder():
                 zorder=3,
                 label="steepest gradient",
             )
-            ax.legend()
-
-        ax.set_xscale("log")
-        ax.set_xlabel("Learning rate")
-        ax.set_ylabel("Loss")
+            ax_loss.legend()
 
 
 def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
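
Note: the new skip_start/skip_end trimming plus the steepest-gradient suggestion can be read in isolation from the training loop. Here is a standalone sketch of just that part of run(), fed with a synthetic loss curve; the helper name suggest_lr and the curve itself are made up for illustration:

import numpy as np

def suggest_lr(lrs, losses, skip_start=10, skip_end=5):
    # Drop the noisy first points and the diverging last points, as in run().
    if skip_end == 0:
        lrs, losses = lrs[skip_start:], losses[skip_start:]
    else:
        lrs, losses = lrs[skip_start:-skip_end], losses[skip_start:-skip_end]

    try:
        # Steepest descent of the loss curve = most negative finite difference.
        min_grad_idx = np.gradient(np.array(losses)).argmin()
    except ValueError:
        print("Failed to compute the gradients, there might not be enough points.")
        return None
    return lrs[min_grad_idx]

# Synthetic curve: flat, a steep drop around 1e-3, divergence past 1e-1.
lrs = list(np.logspace(-6, 0, num=60))
losses = [2.0 - 1.5 / (1.0 + (1e-3 / lr) ** 2) + max(lr - 0.1, 0.0) * 20 for lr in lrs]
print("Suggested LR: {:.2E}".format(suggest_lr(lrs, losses)))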
diff --git a/training/util.py b/training/util.py
index a0c15cd..d0f7fcd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -5,11 +5,6 @@ import torch
 from PIL import Image
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)