From f5b656d21c5b449eed6ce212e909043c124f79ee Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 12 Oct 2022 08:18:22 +0200
Subject: Various updates

---
 dreambooth.py | 53 +++++++++++++++++++++++------------------------------
 1 file changed, 23 insertions(+), 30 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 02f83c6..775aea2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -167,7 +167,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1.0
+        default=7 / 8
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -468,20 +468,20 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+        args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+
+    ema_unet = EMAModel(
+        unet,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -538,7 +538,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -629,16 +629,10 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    ema_unet = EMAModel(
-        unet,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
+    ema_unet.averaged_model.to(accelerator.device)

     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
@@ -698,7 +692,7 @@
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -709,7 +703,7 @@ def main():

     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             unet.train()
@@ -720,9 +714,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
@@ -737,8 +730,7 @@ def main():
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                     # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -840,7 +832,8 @@ def main():
             global_progress_bar.clear()

             if min_val_loss > val_loss:
-                accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                accelerator.print(
+                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 min_val_loss = val_loss

             if sample_checkpoint and accelerator.is_main_process:
--
cgit v1.2.3-54-g00ecf
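
A note on the EMA hunks, not part of the commit itself: the patch constructs
ema_unet as "EMAModel(...) if args.use_ema else None", but the new line
"ema_unet.averaged_model.to(accelerator.device)" runs unconditionally, so a run
without --use_ema would raise an AttributeError on None. A minimal sketch of a
guard, assuming the same variable names as the script:

    # Sketch (not in the patch): move the EMA copy to the device only when
    # EMA is enabled, since ema_unet is None when --use_ema is not passed.
    if ema_unet is not None:
        ema_unet.averaged_model.to(accelerator.device)

On the --ema_power change from 1.0 to 7/8: assuming this EMAModel follows the
usual warmup schedule decay = 1 - (1 + step / inv_gamma) ** (-power), a smaller
power makes the decay climb toward --ema_max_decay more gradually, so the EMA
weights track the live model more closely early in training.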