| author | Volpeon <git@volpeon.ink> | 2022-10-16 14:39:39 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-16 14:39:39 +0200 |
| commit | dee4c7135754543f1eb7ea616ee3847d34a85b51 (patch) | |
| tree | 4064b44bb79e499cf6a8f1ec38a83a4889f067a7 /dreambooth.py | |
| parent | Update (diff) | |
Update
Diffstat (limited to 'dreambooth.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | dreambooth.py | 41 |
1 file changed, 32 insertions, 9 deletions
```diff
diff --git a/dreambooth.py b/dreambooth.py
index 1ba8dc0..9e2645b 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -150,10 +150,16 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_ema",
         action="store_true",
         default=True,
@@ -167,7 +173,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -296,6 +302,13 @@ def parse_args():
     return args
 
 
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -455,6 +468,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -614,12 +629,20 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=args.lr_cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
```
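For reference, the new `cosine_with_restarts` branch can be exercised in isolation. The sketch below is illustrative only and not part of the commit: the stand-in parameter, the `AdamW` optimizer, and the step counts are assumptions chosen to mirror the new defaults, and it simply prints the learning rate so the warmup and the restart behaviour selected via `--lr_cycles` are visible.

```python
# Illustrative sketch only: a dummy parameter/optimizer stands in for the real
# UNet, and the step counts are assumed values matching the commit's defaults.
import torch
from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for unet.parameters()
optimizer = torch.optim.AdamW(params, lr=1e-4)

lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=300,     # new --lr_warmup_steps default
    num_training_steps=3000,  # assumed total number of training steps
    num_cycles=2,             # new --lr_cycles default: two cosine cycles after warmup
)

for step in range(3000):
    optimizer.step()
    lr_scheduler.step()
    if step % 500 == 0:
        print(f"step {step}: lr = {lr_scheduler.get_last_lr()[0]:.2e}")
```

With `num_cycles=2`, the post-warmup phase is split into two full cosine decays, so the learning rate jumps back to its peak once, roughly halfway through training.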
