From f5b656d21c5b449eed6ce212e909043c124f79ee Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 12 Oct 2022 08:18:22 +0200
Subject: Various updates

---
 textual_inversion.py | 74 ++++++++++------------------------------------------
 1 file changed, 14 insertions(+), 60 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index e6d856a..3a3741d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -112,7 +111,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,30 +149,9 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
-    parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        default=True,
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
@@ -348,7 +326,6 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
-        ema_text_encoder,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -363,7 +340,6 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
-        self.ema_text_encoder = ema_text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -380,8 +356,7 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)

         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -400,8 +375,7 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -507,9 +481,7 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
@@ -530,15 +502,10 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+        args.pretrained_model_name_or_path, subfolder='unet')

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -707,13 +674,6 @@ def main():
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    ema_text_encoder = EMAModel(
-        text_encoder,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move vae and unet to device
     vae.to(accelerator.device)
     unet.to(accelerator.device)
@@ -757,7 +717,6 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        ema_text_encoder=ema_text_encoder,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -777,7 +736,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -788,7 +747,7 @@ def main():

     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             text_encoder.train()
@@ -799,9 +758,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
@@ -859,9 +817,6 @@ def main():

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    if args.use_ema:
-                        ema_text_encoder.step(unet)
-
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -881,8 +836,6 @@ def main():
                     })

                 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
-                if args.use_ema:
-                    logs["ema_decay"] = ema_text_encoder.decay

                 accelerator.log(logs, step=global_step)
@@ -937,7 +890,8 @@ def main():
             global_progress_bar.clear()

             if min_val_loss > val_loss:
-                accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                accelerator.print(
+                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                 min_val_loss = val_loss
--
cgit v1.2.3-54-g00ecf