Diffstat (limited to 'dreambooth_plus.py')
 -rw-r--r--  dreambooth_plus.py  34
 1 file changed, 26 insertions, 8 deletions

diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index b5ec2fc..eeee424 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -16,7 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2300,
+        default=1300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -317,6 +317,13 @@ def parse_args():
     return args


+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -503,6 +510,8 @@ def main():

     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -706,12 +715,21 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True

-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=num_update_steps_per_epoch,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )

     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
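
Note on the new cosine_with_restarts branch: diffusers' get_cosine_with_hard_restarts_schedule_with_warmup() takes the optimizer as its first parameter and has no scheduler-name argument, so passing args.lr_scheduler positionally while also passing optimizer=optimizer would raise a TypeError at runtime. Below is a minimal sketch of the intended selection logic, assuming the same argument names as in the diff; build_lr_scheduler() is a hypothetical helper used only for illustration, not part of the commit.

    from diffusers.optimization import (
        get_scheduler,
        get_cosine_with_hard_restarts_schedule_with_warmup,
    )

    def build_lr_scheduler(args, optimizer, num_update_steps_per_epoch):
        # Warmup and total steps are scaled by gradient accumulation, as in the diff.
        num_warmup_steps = args.lr_warmup_steps * args.gradient_accumulation_steps
        num_training_steps = args.max_train_steps * args.gradient_accumulation_steps

        if args.lr_scheduler == "cosine_with_restarts":
            # Cosine schedule with hard restarts; the optimizer is the first
            # parameter, and num_cycles is the value chosen in the diff
            # (the number of update steps per epoch).
            return get_cosine_with_hard_restarts_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=num_update_steps_per_epoch,
            )

        # Every other scheduler name still goes through the generic factory,
        # which does expect the name string as its first argument.
        return get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

The else branch matches the call that existed before this change; only the hard-restarts branch is new, with num_cycles controlling how many times the cosine decay restarts over the course of training.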
