 train_dreambooth.py | 16
 train_lora.py       | 10
 train_ti.py         | 21
 training/lr.py      | 46
 training/util.py    |  5
 5 files changed, 59 insertions(+), 39 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 325fe90..202d52c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -580,12 +580,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -905,9 +903,7 @@ def main():
         if epoch < args.train_text_encoder_epochs:
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
-            freeze_params(text_encoder.parameters())
-
-        sample_checkpoint = False
+            text_encoder.requires_grad_(False)
 
         for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
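Note: torch.nn.Module.requires_grad_(False) flips requires_grad on every parameter the module owns (recursively), so the per-module calls above do the same work as the removed freeze_params loop. A minimal sketch of the equivalence, using a throwaway stand-in module rather than the actual CLIP text encoder:

    import torch.nn as nn

    block = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))  # stand-in, not the real text encoder

    # Old helper: walk the parameters() iterator and clear the flag one by one.
    for param in block.parameters():
        param.requires_grad = False

    # Replacement: a single call on the module clears the flag recursively.
    block.requires_grad_(False)

    assert all(not p.requires_grad for p in block.parameters())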
diff --git a/train_lora.py b/train_lora.py
index ffca304..9a42cae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.lora import LoraAttnProcessor
 from training.optimization import get_one_cycle_schedule
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -513,11 +513,9 @@ def main():
 
     print(f"Training added text embeddings")
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     index_fixed_tokens = torch.arange(len(tokenizer))
     index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
diff --git a/train_ti.py b/train_ti.py
index 870b2ba..d7696e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -533,12 +533,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -548,6 +546,9 @@ def main():
            args.train_batch_size * accelerator.num_processes
        )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
        try:
@@ -715,7 +716,11 @@ def main():
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
-    unet.eval()
+
+    if args.gradient_checkpointing:
+        unet.train()
+    else:
+        unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
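Note: the unet.train()/unet.eval() switch above exists because gradient checkpointing only takes effect while a module reports training mode; in eval mode the activations are simply kept and the memory saving is lost. A sketch of the intended pattern, assuming enable_gradient_checkpointing() was called on the UNet earlier (not shown in this diff) and that the UNet stays frozen either way:

    # The UNet is never optimized here, so its weights take no gradients.
    unet.requires_grad_(False)

    if args.gradient_checkpointing:
        # diffusers gates checkpointing on module.training, so the frozen UNet
        # still has to be in train mode for activation recomputation to happen.
        unet.train()
    else:
        unet.eval()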
diff --git a/training/lr.py b/training/lr.py
index 8e558e1..c1fa3a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -22,10 +22,13 @@ class LRFinder():
         self.model_state = copy.deepcopy(model.state_dict())
         self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
+        best_acc = None
+
         lrs = []
         losses = []
+        accs = []
 
         lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
 
@@ -44,6 +47,7 @@ class LRFinder():
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
             avg_loss = AverageMeter()
+            avg_acc = AverageMeter()
 
             self.model.train()
 
@@ -71,28 +75,37 @@ class LRFinder():
 
                     loss, acc, bsz = self.loss_fn(batch)
                     avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
                     progress_bar.update(1)
 
             lr_scheduler.step()
 
             loss = avg_loss.avg.item()
+            acc = avg_acc.avg.item()
+
             if epoch == 0:
                 best_loss = loss
+                best_acc = acc
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                 if loss < best_loss:
                     best_loss = loss
+                if acc > best_acc:
+                    best_acc = acc
 
             lr = lr_scheduler.get_last_lr()[0]
 
             lrs.append(lr)
             losses.append(loss)
+            accs.append(acc)
 
             progress_bar.set_postfix({
                 "loss": loss,
-                "best": best_loss,
+                "loss/best": best_loss,
+                "acc": acc,
+                "acc/best": best_acc,
                 "lr": lr,
             })
 
@@ -103,20 +116,37 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        fig, ax = plt.subplots()
-        ax.plot(lrs, losses)
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
+        fig, ax_loss = plt.subplots()
+
+        ax_loss.plot(lrs, losses, color='red', label='Loss')
+        ax_loss.set_xscale("log")
+        ax_loss.set_xlabel("Learning rate")
+
+        # ax_acc = ax_loss.twinx()
+        # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
+
         if min_grad_idx is not None:
             print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
-            ax.scatter(
+            ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
                 s=75,
@@ -125,11 +155,7 @@ class LRFinder():
                 zorder=3,
                 label="steepest gradient",
             )
-        ax.legend()
-
-        ax.set_xscale("log")
-        ax.set_xlabel("Learning rate")
-        ax.set_ylabel("Loss")
+        ax_loss.legend()
 
 
 def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
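Note: the suggestion logic added to run() reduces to: drop the first skip_start and last skip_end points of the sweep (the warm-up and divergence tails), then pick the learning rate where the loss curve falls fastest. A standalone sketch of just that step, with hypothetical recorded_lrs/recorded_losses lists:

    import numpy as np

    def suggest_lr(lrs, losses, skip_start=10, skip_end=5):
        # Trim the unstable head and tail of the sweep, mirroring the code above.
        lrs = lrs[skip_start:-skip_end] if skip_end > 0 else lrs[skip_start:]
        losses = losses[skip_start:-skip_end] if skip_end > 0 else losses[skip_start:]
        # Steepest descent = most negative finite-difference gradient of the loss.
        min_grad_idx = np.gradient(np.array(losses)).argmin()
        return lrs[min_grad_idx]

    # e.g. suggested = suggest_lr(recorded_lrs, recorded_losses)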
diff --git a/training/util.py b/training/util.py
index a0c15cd..d0f7fcd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -5,11 +5,6 @@ import torch
 from PIL import Image
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
