path: root/dreambooth_plus.py
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--    dreambooth_plus.py    34
1 file changed, 26 insertions(+), 8 deletions(-)
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index b5ec2fc..eeee424 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -16,7 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2300,
+        default=1300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -317,6 +317,13 @@ def parse_args():
     return args
 
 
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -503,6 +510,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -706,12 +715,21 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=num_update_steps_per_epoch,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
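
One note on the new cosine_with_restarts branch: in diffusers, get_cosine_with_hard_restarts_schedule_with_warmup takes the optimizer as its first parameter and has no scheduler-name argument, so the added call, which passes args.lr_scheduler positionally in addition to optimizer=optimizer, would raise a TypeError ("got multiple values for argument 'optimizer'") when that branch runs. Below is a minimal, self-contained sketch of the intended call, not code from this repository; the stand-in parameters, optimizer, and step counts are placeholders.

# Sketch only: illustrates the diffusers helper's signature with placeholder values.
import torch
from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in trainable parameters
optimizer = torch.optim.AdamW(params, lr=1e-4)  # stand-in optimizer

lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,        # the helper's first parameter; there is no name string
    num_warmup_steps=100,       # placeholder warmup steps
    num_training_steps=1300,    # placeholder total steps
    num_cycles=10,              # number of hard restarts over the schedule
)

for _ in range(5):              # the scheduler is stepped alongside the optimizer
    optimizer.step()
    lr_scheduler.step()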