diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 33 |
1 file changed, 21 insertions, 12 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 6627f1f..2109d13 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -16,7 +16,7 @@ from accelerate import Accelerator | |||
16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
19 | from diffusers.optimization import get_scheduler | 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
20 | from PIL import Image | 20 | from PIL import Image |
21 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
22 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
@@ -303,10 +303,10 @@ def freeze_params(params): | |||
303 | param.requires_grad = False | 303 | param.requires_grad = False |
304 | 304 | ||
305 | 305 | ||
306 | def save_resume_file(basepath, args, extra={}): | 306 | def save_args(basepath: Path, args, extra={}): |
307 | info = {"args": vars(args)} | 307 | info = {"args": vars(args)} |
308 | info["args"].update(extra) | 308 | info["args"].update(extra) |
309 | with open(f"{basepath}/resume.json", "w") as f: | 309 | with open(basepath.joinpath("args.json"), "w") as f: |
310 | json.dump(info, f, indent=4) | 310 | json.dump(info, f, indent=4) |
311 | 311 | ||
312 | 312 | ||
@@ -660,12 +660,21 @@ def main(): | |||
660 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 660 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
661 | overrode_max_train_steps = True | 661 | overrode_max_train_steps = True |
662 | 662 | ||
663 | lr_scheduler = get_scheduler( | 663 | if args.lr_scheduler == "cosine_with_restarts": |
664 | args.lr_scheduler, | 664 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
665 | optimizer=optimizer, | 665 | optimizer=optimizer, |
666 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 666 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
667 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 667 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
668 | ) | 668 | num_cycles=num_update_steps_per_epoch, |
669 | ) | ||
670 | | ||
671 | else: | ||
672 | lr_scheduler = get_scheduler( | ||
673 | args.lr_scheduler, | ||
674 | optimizer=optimizer, | ||
675 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | ||
676 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
677 | ) | ||
669 | 678 | ||
670 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 679 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
671 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 680 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler |
@@ -827,7 +836,7 @@ def main(): | |||
827 | global_progress_bar.clear() | 836 | global_progress_bar.clear() |
828 | 837 | ||
829 | checkpointer.checkpoint(global_step + global_step_offset, "training") | 838 | checkpointer.checkpoint(global_step + global_step_offset, "training") |
830 | save_resume_file(basepath, args, { | 839 | save_args(basepath, args, { |
831 | "global_step": global_step + global_step_offset, | 840 | "global_step": global_step + global_step_offset, |
832 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | 841 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" |
833 | }) | 842 | }) |
@@ -901,7 +910,7 @@ def main(): | |||
901 | if accelerator.is_main_process: | 910 | if accelerator.is_main_process: |
902 | print("Finished! Saving final checkpoint and resume state.") | 911 | print("Finished! Saving final checkpoint and resume state.") |
903 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 912 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
904 | save_resume_file(basepath, args, { | 913 | save_args(basepath, args, { |
905 | "global_step": global_step + global_step_offset, | 914 | "global_step": global_step + global_step_offset, |
906 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | 915 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" |
907 | }) | 916 | }) |
@@ -911,7 +920,7 @@ def main(): | |||
911 | if accelerator.is_main_process: | 920 | if accelerator.is_main_process: |
912 | print("Interrupted, saving checkpoint and resume state...") | 921 | print("Interrupted, saving checkpoint and resume state...") |
913 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 922 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
914 | save_resume_file(basepath, args, { | 923 | save_args(basepath, args, { |
915 | "global_step": global_step + global_step_offset, | 924 | "global_step": global_step + global_step_offset, |
916 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | 925 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" |
917 | }) | 926 | }) |