summary refs log tree commit diff stats
path: root/train_ti.py
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2022-12-27 11:48:33 +0100
committer Volpeon <git@volpeon.ink> 2022-12-27 11:48:33 +0100
commit 30098b1d611853c0d3a4687d84582e1c1cf1b938 (patch)
tree 94817d6ccd2fb7a8a58fb8a6ef6543b6db5b9a51 /train_ti.py
parent Added learning rate finder (diff)
download textual-inversion-diff-30098b1d611853c0d3a4687d84582e1c1cf1b938.tar.gz
textual-inversion-diff-30098b1d611853c0d3a4687d84582e1c1cf1b938.tar.bz2
textual-inversion-diff-30098b1d611853c0d3a4687d84582e1c1cf1b938.zip
Added validation phase to learn rate finder
Diffstat (limited to 'train_ti.py')
-rw-r--r-- train_ti.py 12
1 files changed, 8 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py
index ab00b60..32f44f4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 14from accelerate.utils import LoggerType, set_seed
15from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel 15from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
16from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup 16from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
17import matplotlib.pyplot as plt
17from tqdm.auto import tqdm 18from tqdm.auto import tqdm
18from transformers import CLIPTextModel, CLIPTokenizer 19from transformers import CLIPTextModel, CLIPTokenizer
19from slugify import slugify 20from slugify import slugify
@@ -451,6 +452,7 @@ def main():
451 global_step_offset = args.global_step 452 global_step_offset = args.global_step
452 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 453 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
453 basepath = Path(args.output_dir).joinpath(slugify(args.project), now) 454 basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
455 basepath.mkdir(parents=True, exist_ok=True)
454 456
455 if args.find_lr: 457 if args.find_lr:
456 accelerator = Accelerator( 458 accelerator = Accelerator(
@@ -458,8 +460,6 @@ def main():
458 mixed_precision=args.mixed_precision 460 mixed_precision=args.mixed_precision
459 ) 461 )
460 else: 462 else:
461 basepath.mkdir(parents=True, exist_ok=True)
462
463 accelerator = Accelerator( 463 accelerator = Accelerator(
464 log_with=LoggerType.TENSORBOARD, 464 log_with=LoggerType.TENSORBOARD,
465 logging_dir=f"{basepath}", 465 logging_dir=f"{basepath}",
@@ -782,8 +782,12 @@ def main():
782 return loss, acc, bsz 782 return loss, acc, bsz
783 783
784 if args.find_lr: 784 if args.find_lr:
785 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, loop) 785 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
786 lr_finder.run() 786 lr_finder.run(num_train_steps=2)
787
788 plt.savefig(basepath.joinpath("lr.png"))
789 plt.close()
790
787 quit() 791 quit()
788 792
789 # We need to initialize the trackers we use, and also store our configuration. 793 # We need to initialize the trackers we use, and also store our configuration.