Diffstat (limited to 'textual_inversion.py')
 textual_inversion.py | 32 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 12 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 6627f1f..2109d13 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -16,7 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -303,10 +303,10 @@ def freeze_params(params):
         param.requires_grad = False
 
 
-def save_resume_file(basepath, args, extra={}):
+def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
-    with open(f"{basepath}/resume.json", "w") as f:
+    with open(basepath.joinpath("args.json"), "w") as f:
         json.dump(info, f, indent=4)
 
 
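For context, a minimal usage sketch of the renamed helper, assuming the save_args definition from the hunk above is in scope; the run directory and argparse values are hypothetical, not taken from the script:

from argparse import Namespace
from pathlib import Path

basepath = Path("output/run-001")  # hypothetical run directory
basepath.mkdir(parents=True, exist_ok=True)
args = Namespace(lr_scheduler="cosine_with_restarts", learning_rate=5e-4)

# Merges the extra keys into vars(args) and writes <basepath>/args.json
# (formerly resume.json).
save_args(basepath, args, {"global_step": 1000})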
@@ -660,12 +660,20 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=num_update_steps_per_epoch,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
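Unlike get_scheduler, which takes a scheduler name as its first argument, get_cosine_with_hard_restarts_schedule_with_warmup in diffusers.optimization takes the optimizer first, followed by num_warmup_steps, num_training_steps and num_cycles. A standalone sketch of the restart schedule; the model, learning rate and step counts below are toy placeholders, not the script's values:

import torch
from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

model = torch.nn.Linear(4, 4)  # stand-in for the trainable embedding parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=100,     # linear warmup to the base learning rate
    num_training_steps=1000,  # total steps shared across all cosine cycles
    num_cycles=10,            # hard restarts; the patch derives this from num_update_steps_per_epoch
)

for _ in range(1000):
    optimizer.step()
    lr_scheduler.step()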
@@ -827,7 +835,7 @@ def main():
                     global_progress_bar.clear()
 
                     checkpointer.checkpoint(global_step + global_step_offset, "training")
-                    save_resume_file(basepath, args, {
+                    save_args(basepath, args, {
                         "global_step": global_step + global_step_offset,
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                     })
@@ -901,7 +909,7 @@ def main():
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_resume_file(basepath, args, {
+            save_args(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
@@ -911,7 +919,7 @@ def main():
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_resume_file(basepath, args, {
+            save_args(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
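With this change, resume logic would read args.json rather than resume.json; a sketch of loading the recorded state back, assuming the same hypothetical run directory as above:

import json
from pathlib import Path

basepath = Path("output/run-001")  # hypothetical run directory
with open(basepath / "args.json") as f:
    saved = json.load(f)["args"]

# The extra keys written by save_args sit alongside the original CLI arguments.
print(saved["global_step"], saved["resume_checkpoint"])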