summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py16
-rw-r--r--train_lora.py10
-rw-r--r--train_ti.py21
-rw-r--r--training/lr.py46
-rw-r--r--training/util.py5
5 files changed, 59 insertions, 39 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 325fe90..202d52c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule 25from data.csv import CSVDataModule
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.ti import patch_trainable_embeddings 27from training.ti import patch_trainable_embeddings
28from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args
29from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
30 30
31logger = get_logger(__name__) 31logger = get_logger(__name__)
@@ -580,12 +580,10 @@ def main():
580 580
581 patch_trainable_embeddings(text_encoder, placeholder_token_id) 581 patch_trainable_embeddings(text_encoder, placeholder_token_id)
582 582
583 freeze_params(itertools.chain( 583 text_encoder.text_model.encoder.requires_grad_(False)
584 text_encoder.text_model.encoder.parameters(), 584 text_encoder.text_model.final_layer_norm.requires_grad_(False)
585 text_encoder.text_model.final_layer_norm.parameters(), 585 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
586 text_encoder.text_model.embeddings.position_embedding.parameters(), 586 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
587 text_encoder.text_model.embeddings.token_embedding.parameters(),
588 ))
589 587
590 prompt_processor = PromptProcessor(tokenizer, text_encoder) 588 prompt_processor = PromptProcessor(tokenizer, text_encoder)
591 589
@@ -905,9 +903,7 @@ def main():
905 if epoch < args.train_text_encoder_epochs: 903 if epoch < args.train_text_encoder_epochs:
906 text_encoder.train() 904 text_encoder.train()
907 elif epoch == args.train_text_encoder_epochs: 905 elif epoch == args.train_text_encoder_epochs:
908 freeze_params(text_encoder.parameters()) 906 text_encoder.requires_grad_(False)
909
910 sample_checkpoint = False
911 907
912 for step, batch in enumerate(train_dataloader): 908 for step, batch in enumerate(train_dataloader):
913 with accelerator.accumulate(unet): 909 with accelerator.accumulate(unet):
diff --git a/train_lora.py b/train_lora.py
index ffca304..9a42cae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule 25from data.csv import CSVDataModule
26from training.lora import LoraAttnProcessor 26from training.lora import LoraAttnProcessor
27from training.optimization import get_one_cycle_schedule 27from training.optimization import get_one_cycle_schedule
28from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args
29from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
30 30
31logger = get_logger(__name__) 31logger = get_logger(__name__)
@@ -513,11 +513,9 @@ def main():
513 513
514 print(f"Training added text embeddings") 514 print(f"Training added text embeddings")
515 515
516 freeze_params(itertools.chain( 516 text_encoder.text_model.encoder.requires_grad_(False)
517 text_encoder.text_model.encoder.parameters(), 517 text_encoder.text_model.final_layer_norm.requires_grad_(False)
518 text_encoder.text_model.final_layer_norm.parameters(), 518 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
519 text_encoder.text_model.embeddings.position_embedding.parameters(),
520 ))
521 519
522 index_fixed_tokens = torch.arange(len(tokenizer)) 520 index_fixed_tokens = torch.arange(len(tokenizer))
523 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] 521 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
diff --git a/train_ti.py b/train_ti.py
index 870b2ba..d7696e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
25from training.optimization import get_one_cycle_schedule 25from training.optimization import get_one_cycle_schedule
26from training.lr import LRFinder 26from training.lr import LRFinder
27from training.ti import patch_trainable_embeddings 27from training.ti import patch_trainable_embeddings
28from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params 28from training.util import AverageMeter, CheckpointerBase, save_args
29from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
30 30
31logger = get_logger(__name__) 31logger = get_logger(__name__)
@@ -533,12 +533,10 @@ def main():
533 533
534 patch_trainable_embeddings(text_encoder, placeholder_token_id) 534 patch_trainable_embeddings(text_encoder, placeholder_token_id)
535 535
536 freeze_params(itertools.chain( 536 text_encoder.text_model.encoder.requires_grad_(False)
537 text_encoder.text_model.encoder.parameters(), 537 text_encoder.text_model.final_layer_norm.requires_grad_(False)
538 text_encoder.text_model.final_layer_norm.parameters(), 538 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
539 text_encoder.text_model.embeddings.position_embedding.parameters(), 539 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
540 text_encoder.text_model.embeddings.token_embedding.parameters(),
541 ))
542 540
543 prompt_processor = PromptProcessor(tokenizer, text_encoder) 541 prompt_processor = PromptProcessor(tokenizer, text_encoder)
544 542
@@ -548,6 +546,9 @@ def main():
548 args.train_batch_size * accelerator.num_processes 546 args.train_batch_size * accelerator.num_processes
549 ) 547 )
550 548
549 if args.find_lr:
550 args.learning_rate = 1e2
551
551 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 552 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
552 if args.use_8bit_adam: 553 if args.use_8bit_adam:
553 try: 554 try:
@@ -715,7 +716,11 @@ def main():
715 716
716 # Keep vae and unet in eval mode as we don't train these 717 # Keep vae and unet in eval mode as we don't train these
717 vae.eval() 718 vae.eval()
718 unet.eval() 719
720 if args.gradient_checkpointing:
721 unet.train()
722 else:
723 unet.eval()
719 724
720 # We need to recalculate our total training steps as the size of the training dataloader may have changed. 725 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
721 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 726 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
diff --git a/training/lr.py b/training/lr.py
index 8e558e1..c1fa3a0 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -22,10 +22,13 @@ class LRFinder():
22 self.model_state = copy.deepcopy(model.state_dict()) 22 self.model_state = copy.deepcopy(model.state_dict())
23 self.optimizer_state = copy.deepcopy(optimizer.state_dict()) 23 self.optimizer_state = copy.deepcopy(optimizer.state_dict())
24 24
25 def run(self, min_lr, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): 25 def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
26 best_loss = None 26 best_loss = None
27 best_acc = None
28
27 lrs = [] 29 lrs = []
28 losses = [] 30 losses = []
31 accs = []
29 32
30 lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs) 33 lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
31 34
@@ -44,6 +47,7 @@ class LRFinder():
44 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 47 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
45 48
46 avg_loss = AverageMeter() 49 avg_loss = AverageMeter()
50 avg_acc = AverageMeter()
47 51
48 self.model.train() 52 self.model.train()
49 53
@@ -71,28 +75,37 @@ class LRFinder():
71 75
72 loss, acc, bsz = self.loss_fn(batch) 76 loss, acc, bsz = self.loss_fn(batch)
73 avg_loss.update(loss.detach_(), bsz) 77 avg_loss.update(loss.detach_(), bsz)
78 avg_acc.update(acc.detach_(), bsz)
74 79
75 progress_bar.update(1) 80 progress_bar.update(1)
76 81
77 lr_scheduler.step() 82 lr_scheduler.step()
78 83
79 loss = avg_loss.avg.item() 84 loss = avg_loss.avg.item()
85 acc = avg_acc.avg.item()
86
80 if epoch == 0: 87 if epoch == 0:
81 best_loss = loss 88 best_loss = loss
89 best_acc = acc
82 else: 90 else:
83 if smooth_f > 0: 91 if smooth_f > 0:
84 loss = smooth_f * loss + (1 - smooth_f) * losses[-1] 92 loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
85 if loss < best_loss: 93 if loss < best_loss:
86 best_loss = loss 94 best_loss = loss
95 if acc > best_acc:
96 best_acc = acc
87 97
88 lr = lr_scheduler.get_last_lr()[0] 98 lr = lr_scheduler.get_last_lr()[0]
89 99
90 lrs.append(lr) 100 lrs.append(lr)
91 losses.append(loss) 101 losses.append(loss)
102 accs.append(acc)
92 103
93 progress_bar.set_postfix({ 104 progress_bar.set_postfix({
94 "loss": loss, 105 "loss": loss,
95 "best": best_loss, 106 "loss/best": best_loss,
107 "acc": acc,
108 "acc/best": best_acc,
96 "lr": lr, 109 "lr": lr,
97 }) 110 })
98 111
@@ -103,20 +116,37 @@ class LRFinder():
103 print("Stopping early, the loss has diverged") 116 print("Stopping early, the loss has diverged")
104 break 117 break
105 118
106 fig, ax = plt.subplots() 119 if skip_end == 0:
107 ax.plot(lrs, losses) 120 lrs = lrs[skip_start:]
121 losses = losses[skip_start:]
122 accs = accs[skip_start:]
123 else:
124 lrs = lrs[skip_start:-skip_end]
125 losses = losses[skip_start:-skip_end]
126 accs = accs[skip_start:-skip_end]
127
128 fig, ax_loss = plt.subplots()
129
130 ax_loss.plot(lrs, losses, color='red', label='Loss')
131 ax_loss.set_xscale("log")
132 ax_loss.set_xlabel("Learning rate")
133
134 # ax_acc = ax_loss.twinx()
135 # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
108 136
109 print("LR suggestion: steepest gradient") 137 print("LR suggestion: steepest gradient")
110 min_grad_idx = None 138 min_grad_idx = None
139
111 try: 140 try:
112 min_grad_idx = (np.gradient(np.array(losses))).argmin() 141 min_grad_idx = (np.gradient(np.array(losses))).argmin()
113 except ValueError: 142 except ValueError:
114 print( 143 print(
115 "Failed to compute the gradients, there might not be enough points." 144 "Failed to compute the gradients, there might not be enough points."
116 ) 145 )
146
117 if min_grad_idx is not None: 147 if min_grad_idx is not None:
118 print("Suggested LR: {:.2E}".format(lrs[min_grad_idx])) 148 print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
119 ax.scatter( 149 ax_loss.scatter(
120 lrs[min_grad_idx], 150 lrs[min_grad_idx],
121 losses[min_grad_idx], 151 losses[min_grad_idx],
122 s=75, 152 s=75,
@@ -125,11 +155,7 @@ class LRFinder():
125 zorder=3, 155 zorder=3,
126 label="steepest gradient", 156 label="steepest gradient",
127 ) 157 )
128 ax.legend() 158 ax_loss.legend()
129
130 ax.set_xscale("log")
131 ax.set_xlabel("Learning rate")
132 ax.set_ylabel("Loss")
133 159
134 160
135def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1): 161def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
diff --git a/training/util.py b/training/util.py
index a0c15cd..d0f7fcd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -5,11 +5,6 @@ import torch
5from PIL import Image 5from PIL import Image
6 6
7 7
8def freeze_params(params):
9 for param in params:
10 param.requires_grad = False
11
12
13def save_args(basepath: Path, args, extra={}): 8def save_args(basepath: Path, args, extra={}):
14 info = {"args": vars(args)} 9 info = {"args": vars(args)}
15 info["args"].update(extra) 10 info["args"].update(extra)