From fcbc11be99c011ab1003451ef72c95ca587902d8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Oct 2022 18:42:27 +0200
Subject: Update

---
 textual_inversion.py | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index 6627f1f..2109d13 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -16,7 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -303,10 +303,10 @@ def freeze_params(params):
         param.requires_grad = False
 
 
-def save_resume_file(basepath, args, extra={}):
+def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
-    with open(f"{basepath}/resume.json", "w") as f:
+    with open(basepath.joinpath("args.json"), "w") as f:
         json.dump(info, f, indent=4)
 
 
@@ -660,12 +660,21 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=num_update_steps_per_epoch,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
@@ -827,7 +836,7 @@ def main():
                     global_progress_bar.clear()
 
                     checkpointer.checkpoint(global_step + global_step_offset, "training")
-                    save_resume_file(basepath, args, {
+                    save_args(basepath, args, {
                        "global_step": global_step + global_step_offset,
                        "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                    })
@@ -901,7 +910,7 @@ def main():
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_resume_file(basepath, args, {
+            save_args(basepath, args, {
                "global_step": global_step + global_step_offset,
                "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
            })
@@ -911,7 +920,7 @@ def main():
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_resume_file(basepath, args, {
+            save_args(basepath, args, {
                "global_step": global_step + global_step_offset,
                "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
            })
-- 
cgit v1.2.3-54-g00ecf
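
Note on the new cosine_with_restarts branch: in diffusers, get_cosine_with_hard_restarts_schedule_with_warmup takes the optimizer as its first positional parameter, so passing args.lr_scheduler positionally while also passing optimizer=optimizer as a keyword supplies the optimizer twice and raises a TypeError when this branch runs. Below is a minimal sketch of that branch without the stray argument; it reuses only the names already present in the patch (args, optimizer, num_update_steps_per_epoch) and is not part of the commit itself.

    from diffusers.optimization import (
        get_scheduler,
        get_cosine_with_hard_restarts_schedule_with_warmup,
    )

    if args.lr_scheduler == "cosine_with_restarts":
        # Cosine schedule with hard restarts; this factory takes the optimizer
        # directly and has no scheduler-name argument.
        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
            num_cycles=num_update_steps_per_epoch,  # one restart cycle per epoch, as in the patch
        )
    else:
        # Every other scheduler name still goes through the generic factory.
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        )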