diff options
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r-- | dreambooth_plus.py | 34 |
1 file changed, 26 insertions, 8 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index b5ec2fc..eeee424 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -16,7 +16,7 @@ from accelerate import Accelerator | |||
16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel |
19 | from diffusers.optimization import get_scheduler | 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
20 | from diffusers.training_utils import EMAModel | 20 | from diffusers.training_utils import EMAModel |
21 | from PIL import Image | 21 | from PIL import Image |
22 | from tqdm.auto import tqdm | 22 | from tqdm.auto import tqdm |
@@ -118,7 +118,7 @@ def parse_args(): | |||
118 | parser.add_argument( | 118 | parser.add_argument( |
119 | "--max_train_steps", | 119 | "--max_train_steps", |
120 | type=int, | 120 | type=int, |
121 | default=2300, | 121 | default=1300, |
122 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 122 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
123 | ) | 123 | ) |
124 | parser.add_argument( | 124 | parser.add_argument( |
@@ -317,6 +317,13 @@ def parse_args(): | |||
317 | return args | 317 | return args |
318 | 318 | ||
319 | 319 | ||
320 | def save_args(basepath: Path, args, extra={}): | ||
321 | info = {"args": vars(args)} | ||
322 | info["args"].update(extra) | ||
323 | with open(basepath.joinpath("args.json"), "w") as f: | ||
324 | json.dump(info, f, indent=4) | ||
325 | |||
326 | |||
320 | def freeze_params(params): | 327 | def freeze_params(params): |
321 | for param in params: | 328 | for param in params: |
322 | param.requires_grad = False | 329 | param.requires_grad = False |
@@ -503,6 +510,8 @@ def main(): | |||
503 | 510 | ||
504 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 511 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) |
505 | 512 | ||
513 | save_args(basepath, args) | ||
514 | |||
506 | # If passed along, set the training seed now. | 515 | # If passed along, set the training seed now. |
507 | if args.seed is not None: | 516 | if args.seed is not None: |
508 | set_seed(args.seed) | 517 | set_seed(args.seed) |
@@ -706,12 +715,21 @@ def main(): | |||
706 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 715 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
707 | overrode_max_train_steps = True | 716 | overrode_max_train_steps = True |
708 | 717 | ||
709 | lr_scheduler = get_scheduler( | 718 | if args.lr_scheduler == "cosine_with_restarts": |
710 | args.lr_scheduler, | 719 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
711 | optimizer=optimizer, | 720 | # NOTE(review): do not pass args.lr_scheduler here; first parameter is the optimizer |
712 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 721 | optimizer=optimizer, |
713 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 722 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
714 | ) | 723 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
724 | num_cycles=num_update_steps_per_epoch, | ||
725 | ) | ||
726 | else: | ||
727 | lr_scheduler = get_scheduler( | ||
728 | args.lr_scheduler, | ||
729 | optimizer=optimizer, | ||
730 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | ||
731 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
732 | ) | ||
715 | 733 | ||
716 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 734 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
717 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 735 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |