-rw-r--r--  train_ti.py     12
-rw-r--r--  training/lr.py  34
2 files changed, 31 insertions, 15 deletions
diff --git a/train_ti.py b/train_ti.py
index ab00b60..32f44f4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -451,6 +452,7 @@ def main():
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
+    basepath.mkdir(parents=True, exist_ok=True)
 
     if args.find_lr:
         accelerator = Accelerator(
@@ -458,8 +460,6 @@ def main():
             mixed_precision=args.mixed_precision
         )
     else:
-        basepath.mkdir(parents=True, exist_ok=True)
-
         accelerator = Accelerator(
             log_with=LoggerType.TENSORBOARD,
             logging_dir=f"{basepath}",
@@ -782,8 +782,12 @@ def main():
         return loss, acc, bsz
 
     if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop)
-        lr_finder.run()
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(num_train_steps=2)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
         quit()
 
     # We need to initialize the trackers we use, and also store our configuration.
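For context on the call above: LRFinder only assumes that its last argument (`loop` here) is a callable taking a batch and returning a `(loss, acc, bsz)` tuple, which is what the finder's train and validation passes consume. The sketch below is a hypothetical stand-in for that contract, not the actual `loop` closure from train_ti.py (the real one computes the denoising loss); the batch key "input_ids" and the placeholder loss are assumptions.

# Hypothetical stand-in illustrating the loss_fn contract LRFinder expects;
# the real `loop` in train_ti.py computes the diffusion training loss
# instead of this placeholder MSE.
import torch
import torch.nn.functional as F

def make_loop(text_encoder):
    def loop(batch):
        bsz = batch["input_ids"].shape[0]
        hidden = text_encoder(batch["input_ids"])[0]
        # Any differentiable scalar works for an LR sweep.
        loss = F.mse_loss(hidden, torch.zeros_like(hidden))
        acc = (loss.detach() < 1.0).float()  # placeholder "accuracy"
        return loss, acc, bsz
    return loop
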
diff --git a/training/lr.py b/training/lr.py
index dd37baa..5343f24 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,20 +1,22 @@
+import matplotlib.pyplot as plt
 import numpy as np
+import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm
-import matplotlib.pyplot as plt
 
 from training.util import AverageMeter
 
 
 class LRFinder():
-    def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn):
+    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
         self.accelerator = accelerator
         self.model = model
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
 
-    def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5):
+    def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5):
         best_loss = None
         lrs = []
         losses = []
@@ -22,7 +24,7 @@ class LRFinder():
         lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
 
         progress_bar = tqdm(
-            range(num_epochs * num_steps),
+            range(num_epochs * (num_train_steps + num_val_steps)),
             disable=not self.accelerator.is_local_main_process,
             dynamic_ncols=True
         )
@@ -33,6 +35,8 @@ class LRFinder():
 
             avg_loss = AverageMeter()
 
+            self.model.train()
+
             for step, batch in enumerate(self.train_dataloader):
                 with self.accelerator.accumulate(self.model):
                     loss, acc, bsz = self.loss_fn(batch)
@@ -42,13 +46,24 @@ class LRFinder():
                     self.optimizer.step()
                     self.optimizer.zero_grad(set_to_none=True)
 
-                    avg_loss.update(loss.detach_(), bsz)
+                if self.accelerator.sync_gradients:
+                    progress_bar.update(1)
 
-                    if step >= num_steps:
-                        break
+                if step >= num_train_steps:
+                    break
 
-                if self.accelerator.sync_gradients:
-                    progress_bar.update(1)
+            self.model.eval()
+
+            with torch.inference_mode():
+                for step, batch in enumerate(self.val_dataloader):
+                    loss, acc, bsz = self.loss_fn(batch)
+                    avg_loss.update(loss.detach_(), bsz)
+
+                    if self.accelerator.sync_gradients:
+                        progress_bar.update(1)
+
+                    if step >= num_val_steps:
+                        break
 
             lr_scheduler.step()
 
@@ -104,9 +119,6 @@ class LRFinder():
         ax.set_xlabel("Learning rate")
         ax.set_ylabel("Loss")
 
-        if fig is not None:
-            plt.show()
-
 
 def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
     def lr_lambda(current_epoch: int):
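The last hunk ends at the `lr_lambda` definition, so the body of `get_exponential_schedule` is not visible in this diff. As a rough sketch only (the multiplier bound `end_factor` is an assumption, not a value taken from the repository), an exponential LR sweep built on `LambdaLR` generally looks like this:

# Illustrative sketch of an exponential LR sweep as a LambdaLR multiplier.
# The effective LR grows from the optimizer's base LR up to
# base_lr * end_factor over num_epochs; end_factor is an assumed bound,
# not the value used by training/lr.py.
from torch.optim.lr_scheduler import LambdaLR

def exponential_sweep(optimizer, num_epochs, end_factor=100.0, last_epoch=-1):
    def lr_lambda(current_epoch: int):
        return end_factor ** (current_epoch / num_epochs)
    return LambdaLR(optimizer, lr_lambda, last_epoch)

LRFinder.run steps the schedule once per epoch, collecting the smoothed validation loss at each learning rate, and the figure saved as lr.png by train_ti.py plots loss against learning rate so a reasonable starting LR can be read off before the real training run.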