author     Volpeon <git@volpeon.ink>   2022-10-16 14:39:39 +0200
committer  Volpeon <git@volpeon.ink>   2022-10-16 14:39:39 +0200
commit     dee4c7135754543f1eb7ea616ee3847d34a85b51 (patch)
tree       4064b44bb79e499cf6a8f1ec38a83a4889f067a7 /dreambooth.py
parent     Update (diff)
download   textual-inversion-diff-dee4c7135754543f1eb7ea616ee3847d34a85b51.tar.gz
           textual-inversion-diff-dee4c7135754543f1eb7ea616ee3847d34a85b51.tar.bz2
           textual-inversion-diff-dee4c7135754543f1eb7ea616ee3847d34a85b51.zip

Update

Diffstat (limited to 'dreambooth.py')
 -rw-r--r--  dreambooth.py | 41
 1 file changed, 32 insertions(+), 9 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 1ba8dc0..9e2645b 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -150,10 +150,16 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_ema",
         action="store_true",
         default=True,
@@ -167,7 +173,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -296,6 +302,13 @@ def parse_args():
     return args


+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -455,6 +468,8 @@ def main():

     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -614,12 +629,20 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True

-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=args.lr_cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )

     unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
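
For reference, a minimal sketch (not taken from this commit) of what the new cosine_with_restarts branch produces, assuming diffusers and torch are installed. The AdamW optimizer, the 5e-6 learning rate, and the 3000 total steps are illustrative stand-ins rather than values from dreambooth.py; only the 300 warmup steps and 2 cycles mirror the new --lr_warmup_steps and --lr_cycles defaults.

import torch
from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

# Dummy single-parameter optimizer; stands in for the UNet optimizer prepared in main().
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-6)

# Same call as the new branch, with the new defaults: 300 warmup steps, 2 hard restarts.
# 3000 training steps is an arbitrary illustrative length.
lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=300,
    num_training_steps=3000,
    num_cycles=2,
)

for step in range(3000):
    optimizer.step()
    lr_scheduler.step()
    if step in (0, 299, 1648, 1649, 2999):
        # The LR ramps up linearly over the first 300 steps, decays along a cosine to
        # near zero around step 1650, jumps back to the peak for a second cycle, and
        # decays to ~0 again by the end.
        print(step, lr_scheduler.get_last_lr()[0])

The else branch keeps get_scheduler() for every other --lr_scheduler value, so existing configurations behave as before.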