From 83725794618164210a12843381724252fdd82cc2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 28 Dec 2022 18:08:36 +0100
Subject: Integrated updates from diffusers

---
 train_dreambooth.py | 16 ++++++----------
 train_lora.py       | 10 ++++------
 train_ti.py         | 21 +++++++++++++--------
 training/lr.py      | 46 ++++++++++++++++++++++++++++++++++++----------
 training/util.py    |  5 -----
 5 files changed, 59 insertions(+), 39 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 325fe90..202d52c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -580,12 +580,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -905,9 +903,7 @@ def main():
         if epoch < args.train_text_encoder_epochs:
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
-            freeze_params(text_encoder.parameters())
-
-        sample_checkpoint = False
+            text_encoder.requires_grad_(False)
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
diff --git a/train_lora.py b/train_lora.py
index ffca304..9a42cae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.lora import LoraAttnProcessor
 from training.optimization import get_one_cycle_schedule
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -513,11 +513,9 @@ def main():
 
     print(f"Training added text embeddings")
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     index_fixed_tokens = torch.arange(len(tokenizer))
     index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
diff --git a/train_ti.py b/train_ti.py
index 870b2ba..d7696e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -533,12 +533,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -548,6 +546,9 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -715,7 +716,11 @@ def main():
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
-    unet.eval()
+
+    if args.gradient_checkpointing:
+        unet.train()
+    else:
+        unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
diff --git a/training/lr.py b/training/lr.py
index 8e558e1..c1fa3a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -22,10 +22,13 @@ class LRFinder():
         self.model_state = copy.deepcopy(model.state_dict())
         self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
-    def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
+    def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
+        best_acc = None
+
         lrs = []
         losses = []
+        accs = []
 
         lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
 
@@ -44,6 +47,7 @@ class LRFinder():
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
             avg_loss = AverageMeter()
+            avg_acc = AverageMeter()
 
             self.model.train()
 
@@ -71,28 +75,37 @@ class LRFinder():
                     loss, acc, bsz = self.loss_fn(batch)
 
                     avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
                     progress_bar.update(1)
 
             lr_scheduler.step()
 
             loss = avg_loss.avg.item()
+            acc = avg_acc.avg.item()
+
             if epoch == 0:
                 best_loss = loss
+                best_acc = acc
             else:
                 if smooth_f > 0:
                     loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
                 if loss < best_loss:
                     best_loss = loss
+                if acc > best_acc:
+                    best_acc = acc
 
             lr = lr_scheduler.get_last_lr()[0]
 
             lrs.append(lr)
             losses.append(loss)
+            accs.append(acc)
 
             progress_bar.set_postfix({
                 "loss": loss,
-                "best": best_loss,
+                "loss/best": best_loss,
+                "acc": acc,
+                "acc/best": best_acc,
                 "lr": lr,
             })
 
@@ -103,20 +116,37 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        fig, ax = plt.subplots()
-        ax.plot(lrs, losses)
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
+        fig, ax_loss = plt.subplots()
+
+        ax_loss.plot(lrs, losses, color='red', label='Loss')
+        ax_loss.set_xscale("log")
+        ax_loss.set_xlabel("Learning rate")
+
+        # ax_acc = ax_loss.twinx()
+        # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
+
         if min_grad_idx is not None:
             print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
-            ax.scatter(
+            ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
                 s=75,
@@ -125,11 +155,7 @@ class LRFinder():
                 zorder=3,
                 label="steepest gradient",
             )
-        ax.legend()
-
-        ax.set_xscale("log")
-        ax.set_xlabel("Learning rate")
-        ax.set_ylabel("Loss")
+        ax_loss.legend()
 
 
 def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
diff --git a/training/util.py b/training/util.py
index a0c15cd..d0f7fcd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -5,11 +5,6 @@ import torch
 from PIL import Image
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
-- 
cgit v1.2.3-70-g09d2
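
The change that recurs across train_dreambooth.py, train_lora.py and train_ti.py drops the custom freeze_params() helper in favour of torch.nn.Module.requires_grad_(False), which sets requires_grad in place on every parameter owned by the submodule it is called on. A minimal standalone sketch of that idiom, assuming the transformers CLIPTextModel layout these scripts use (the checkpoint name below is only an illustrative placeholder, not what the scripts load):

    from transformers import CLIPTextModel

    # Illustrative checkpoint; the training scripts load their own pretrained path.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    # requires_grad_(False) walks all parameters of the submodule, covering the
    # same tensors the removed freeze_params(itertools.chain(...)) loop iterated over.
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    # Only parameters that still require gradients should be handed to the optimizer.
    trainable = [p for p in text_encoder.parameters() if p.requires_grad]
    print(f"{sum(p.numel() for p in trainable):,} trainable parameters remain")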