Diffstat (limited to 'textual_inversion.py')

 -rw-r--r--  textual_inversion.py | 41
 1 file changed, 25 insertions(+), 16 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index da7c747..a9c3326 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -107,7 +107,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -119,6 +119,12 @@ def parse_args():
         help="Whether to center crop images before resizing to resolution"
     )
     parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
         "--dataloader_num_workers",
         type=int,
         default=0,
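The new `--tag_dropout` value is wired into the dataset as `dropout` in the `main()` hunk further down. The dataset implementation is outside this diff; as a rough sketch of what tag dropout typically means for comma-separated caption tags (all names below are assumptions, not code from this repo):

    import random

    def apply_tag_dropout(tags: list[str], dropout: float = 0.1) -> list[str]:
        # Keep each caption tag independently with probability 1 - dropout,
        # so the learned embedding does not latch onto tags that always
        # co-occur with the placeholder token. A real implementation would
        # typically exempt the placeholder token itself from dropout.
        return [tag for tag in tags if random.random() >= dropout]

    # apply_tag_dropout(["<token>", "outdoors", "sunny"], dropout=0.1)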
@@ -171,9 +177,9 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
+        "--lr_warmup_epochs",
         type=int,
-        default=300,
-        help="Number of steps for the warmup in the lr scheduler."
+        default=10,
+        help="Number of epochs for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -237,7 +243,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -267,7 +273,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=30,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -399,28 +405,28 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
             # Save a checkpoint
-            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
             torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
-        del unwrapped
+        del text_encoder
         del learned_embeds
 
     @torch.no_grad()
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a sample image
         pipeline = VlpnStableDiffusion(
-            text_encoder=unwrapped,
+            text_encoder=text_encoder,
             vae=self.vae,
             unet=self.unet,
             tokenizer=self.tokenizer,
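Each checkpoint written above is a single-entry dict mapping the placeholder token string to its embedding vector. Loading such a `.bin` back follows the usual textual-inversion pattern; a minimal sketch, assuming a Hugging Face `tokenizer`/`text_encoder` pair (this loader is not part of the repo):

    import torch

    def load_learned_embed(path, tokenizer, text_encoder):
        # The checkpoint is {placeholder_token: embedding_tensor}.
        learned = torch.load(path, map_location="cpu")
        token, embed = next(iter(learned.items()))
        # Register the placeholder token and grow the embedding matrix.
        tokenizer.add_tokens(token)
        text_encoder.resize_token_embeddings(len(tokenizer))
        token_id = tokenizer.convert_tokens_to_ids(token)
        with torch.no_grad():
            text_encoder.get_input_embeddings().weight[token_id] = embed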
@@ -471,7 +477,7 @@ class Checkpointer:
                     negative_prompt=nprompt,
                     height=self.sample_image_size,
                     width=self.sample_image_size,
-                    latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                    image=latents[:len(prompt)] if latents is not None else None,
                     generator=generator if latents is not None else None,
                     guidance_scale=guidance_scale,
                     eta=eta,
@@ -489,7 +495,7 @@ class Checkpointer:
         del all_samples
         del image_grid
 
-        del unwrapped
+        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -662,6 +668,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        dropout=args.tag_dropout,
         center_crop=args.center_crop,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -720,6 +727,8 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
     if args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
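`warmup_steps` converts the new epoch-based flag into scheduler steps. The extra factor of `args.gradient_accumulation_steps` mirrors the `num_training_steps` expressions below, which scale `args.max_train_steps` by the same factor, presumably because the scheduler is stepped once per batch rather than once per optimizer update. For example:

    # 10 warmup epochs, 500 optimizer updates per epoch, accumulation of 2:
    warmup_steps = 10 * 500 * 2  # = 10000 scheduler steps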
@@ -728,7 +737,7 @@ def main():
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
+                ((args.max_train_steps / num_update_steps_per_epoch) - args.lr_warmup_epochs))),
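When `--lr_cycles` is not given, the default number of restart cycles works out to roughly the square root of the post-warmup epoch count. For example:

    import math
    # 100 total training epochs with 10 warmup epochs:
    num_cycles = math.ceil(math.sqrt(100 - 10))  # ceil(9.49...) = 10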
@@ -737,7 +746,7 @@ def main():
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
         )
 
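Taken together, an invocation exercising the changed defaults and new flags might look like the following; only options visible in this diff are shown, and required arguments such as the model and data paths are elided:

    python textual_inversion.py \
        --resolution 768 \
        --tag_dropout 0.1 \
        --lr_warmup_epochs 10 \
        --lr_scheduler cosine_with_restarts \
        --sample_image_size 768 \
        --sample_steps 15
        # plus the required model/data arguments, which this diff does not show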