From dd02ace41f69541044e9db106feaa76bf02da8f6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 12 Dec 2022 08:05:06 +0100
Subject: Dreambooth: Support loading Textual Inversion embeddings

---
 textual_inversion.py | 37 +++++++++++++++++++++++--------------
 1 file changed, 23 insertions(+), 14 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index da7c747..a9c3326 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -107,7 +107,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -118,6 +118,12 @@
         action="store_true",
         help="Whether to center crop images before resizing to resolution"
     )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -171,9 +177,9 @@
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
+        "--lr_warmup_epochs",
         type=int,
-        default=300,
+        default=10,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -237,7 +243,7 @@
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -267,7 +273,7 @@
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=30,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -399,28 +405,28 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
             # Save a checkpoint
-            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
             torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
-        del unwrapped
+        del text_encoder
         del learned_embeds
 
     @torch.no_grad()
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a sample image
         pipeline = VlpnStableDiffusion(
-            text_encoder=unwrapped,
+            text_encoder=text_encoder,
             vae=self.vae,
             unet=self.unet,
             tokenizer=self.tokenizer,
@@ -471,7 +477,7 @@ class Checkpointer:
                         negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
-                        latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                        image=latents[:len(prompt)] if latents is not None else None,
                         generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
@@ -489,7 +495,7 @@ class Checkpointer:
 
         del all_samples
         del image_grid
-        del unwrapped
+        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -662,6 +668,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        dropout=args.tag_dropout,
         center_crop=args.center_crop,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -720,6 +727,8 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
     if args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
@@ -728,7 +737,7 @@
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
                 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
@@ -737,7 +746,7 @@
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
         )

--
cgit v1.2.3-54-g00ecf
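
The new --tag_dropout flag is threaded into the dataset constructor as dropout=args.tag_dropout; the dataset code that consumes it is not part of this diff. A minimal sketch of what tag dropout usually means for comma-separated captions (function name and caption format assumed, not taken from this repository):

    import random

    def apply_tag_dropout(caption: str, dropout: float = 0.1) -> str:
        # Drop each comma-separated tag independently with probability `dropout`,
        # so training does not over-rely on any single tag always being present.
        tags = [t.strip() for t in caption.split(",")]
        kept = [t for t in tags if random.random() >= dropout]
        # Never hand the tokenizer an empty caption; fall back to the original.
        return ", ".join(kept) if kept else caption

With the default dropout of 0.1, a caption like "portrait, red hair, outdoors" keeps each tag with probability 0.9, so it may come out as "portrait, outdoors" in one epoch and unchanged in the next.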
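--lr_warmup_steps (a fixed 300 steps) becomes --lr_warmup_epochs (default 10), and main() converts epochs back into the raw step counts the schedulers expect, so warmup now scales with dataset size and gradient accumulation instead of being constant. A worked example of that conversion, with invented numbers:

    import math

    # Invented values, only to illustrate the arithmetic in main():
    batches_per_epoch = 400              # len(train_dataloader)
    gradient_accumulation_steps = 4
    lr_warmup_epochs = 10

    # One optimizer update happens every `gradient_accumulation_steps` batches.
    num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 100

    # The schedulers count pre-accumulation steps, hence the re-multiplication,
    # mirroring how num_training_steps is scaled in the same calls.
    warmup_steps = lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps  # 4000

One caveat in the patch itself: the unchanged num_cycles context line still reads args.lr_warmup_steps, an attribute this commit removes, so the cosine_with_restarts branch would raise AttributeError until a follow-up substitutes the new epoch-based value there.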
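Checkpointer.checkpoint() writes each learned embedding as a one-entry dict, {placeholder_token: tensor}, via torch.save, which is the common textual-inversion checkpoint format. The Dreambooth-side loading promised by the subject line is not in this diff; as a sketch, loading such a file back into a Hugging Face tokenizer/text-encoder pair could look like the following (function name hypothetical; the tokenizer and encoder methods are standard transformers API):

    import torch

    def load_ti_embedding(path, tokenizer, text_encoder):
        # Checkpoints hold {placeholder_token: embedding}, exactly as written
        # by Checkpointer.checkpoint() above.
        learned_embeds_dict = torch.load(path, map_location="cpu")
        for token, embedding in learned_embeds_dict.items():
            if tokenizer.add_tokens(token) == 0:
                raise ValueError(f"Token {token!r} already exists in the tokenizer")
            # Grow the embedding matrix to cover the newly added token,
            # then overwrite that row with the learned vector.
            text_encoder.resize_token_embeddings(len(tokenizer))
            token_id = tokenizer.convert_tokens_to_ids(token)
            with torch.no_grad():
                text_encoder.get_input_embeddings().weight[token_id] = embedding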