From db0996c299fdd559ebf9cd48f9dbe47474ed7b07 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 13 Oct 2022 09:45:27 +0200 Subject: Added TI+Dreambooth training --- dreambooth.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 775aea2..699313e 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -129,7 +129,7 @@ def parse_args(): parser.add_argument( "--learning_rate", type=float, - default=1e-6, + default=5e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -167,7 +167,7 @@ def parse_args(): parser.add_argument( "--ema_power", type=float, - default=7 / 8 + default=6 / 7 ) parser.add_argument( "--ema_max_decay", @@ -269,6 +269,11 @@ def parse_args(): type=float, help="Max gradient norm." ) + parser.add_argument( + "--noise_timesteps", + type=int, + default=1000, + ) parser.add_argument( "--config", type=str, @@ -480,7 +485,8 @@ def main(): unet, inv_gamma=args.ema_inv_gamma, power=args.ema_power, - max_value=args.ema_max_decay + max_value=args.ema_max_decay, + device=accelerator.device ) if args.use_ema else None if args.gradient_checkpointing: @@ -523,7 +529,7 @@ def main(): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000 + num_train_timesteps=args.noise_timesteps ) def collate_fn(examples): @@ -632,7 +638,6 @@ def main(): # Move text_encoder and vae to device text_encoder.to(accelerator.device) vae.to(accelerator.device) - ema_unet.averaged_model.to(accelerator.device) # Keep text_encoder and vae in eval mode as we don't train these text_encoder.eval() -- cgit v1.2.3-54-g00ecf