path: root/textual_inversion.py
author     Volpeon <git@volpeon.ink>    2022-10-07 08:34:07 +0200
committer  Volpeon <git@volpeon.ink>    2022-10-07 08:34:07 +0200
commit     14beba63391e1ddc9a145bb638d9306086ad1a5c (patch)
tree       4e7d5126359c4ab6ab6dff3c2af537d659e276e8 /textual_inversion.py
parent     Update (diff)
Training: Create multiple class images per training image
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py | 48
1 file changed, 35 insertions(+), 13 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 11c324d..86fcdfe 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -68,16 +68,17 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--use_class_images",
-        action="store_true",
-        default=True,
-        help="Include class images in the loss calculation a la Dreambooth.",
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
     )
     parser.add_argument(
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
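
Note: the hunk above turns the boolean --use_class_images switch into an integer count, so "off" is now spelled --num_class_images 0 (the checks later in the diff compare against 0 for exactly this reason). A minimal, self-contained sketch of the new argument in isolation, with names taken from the diff and the parser setup assumed:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=2,
        help="How many class images to generate per training image."
    )

    # 0 now plays the role the old --use_class_images=False did:
    args = parser.parse_args(["--num_class_images", "0"])
    assert args.num_class_images == 0  # prior preservation disabled
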
@@ -204,6 +205,12 @@ def parse_args():
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=100,
+        help="How often to save a sample image",
+    )
+    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=512,
@@ -381,6 +388,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -577,7 +585,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]
 
         # concat class and instance examples for prior preservation
-        if args.use_class_images and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
 
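
The collate order here matters: instance examples come first and class examples are appended after, which is what makes the torch.chunk(..., 2, dim=0) split in the training loop valid. A toy illustration of that contract (tensor shapes invented for the example):

    import torch

    instance = torch.ones(2, 4, 8, 8)    # instance half of the batch
    cls = torch.zeros(2, 4, 8, 8)        # class ("prior") half

    batch = torch.cat([instance, cls], dim=0)      # collate: instance first
    first, second = torch.chunk(batch, 2, dim=0)   # training loop: split back
    assert torch.equal(first, instance)
    assert torch.equal(second, cls)
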
@@ -599,8 +607,9 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token if args.use_class_images else None,
+        class_identifier=args.initializer_token,
         class_subdir="ti_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
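
The datamodule now receives num_class_images and is expected to pair every training image with that many generated class images under class_subdir. Its implementation is not part of this diff; the following is a hypothetical sketch of such a pairing, with the path layout and helper name invented for illustration:

    from pathlib import Path

    def class_image_paths(instance_path: Path, class_dir: Path, num_class_images: int):
        # Hypothetical layout: data/ti_cls/dog-1.jpg ... dog-N.jpg for data/dog.jpg.
        return [
            class_dir / f"{instance_path.stem}-{i + 1}{instance_path.suffix}"
            for i in range(num_class_images)
        ]

    paths = class_image_paths(Path("data/dog.jpg"), Path("data/ti_cls"), 2)
    # -> [.../ti_cls/dog-1.jpg, .../ti_cls/dog-2.jpg]; the item[1].exists()
    #    check below finds exactly such paths before generating missing images.
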
@@ -611,7 +620,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    if args.use_class_images:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]
 
         if len(missing_data) != 0:
@@ -630,6 +639,7 @@ def main():
                 scheduler=scheduler,
             ).to(accelerator.device)
             pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
             for batch in batched_data:
                 image_name = [p[1] for p in batch]
@@ -729,11 +739,18 @@ def main():
             text_encoder,
             args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")
 
-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")
 
     try:
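
dynamic_ncols=True makes tqdm re-measure the terminal width on every refresh instead of fixing it at construction time, so the bars stay intact when the window is resized; the set_progress_bar_config calls above apply the same option to the pipelines' internal bars. A standalone example:

    import time
    from tqdm import tqdm

    for _ in tqdm(range(100), desc="Total progress", dynamic_ncols=True):
        time.sleep(0.01)  # bar adapts if the terminal is resized mid-run
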
@@ -744,6 +761,8 @@ def main():
             text_encoder.train()
             train_loss = 0.0
 
+            sample_checkpoint = False
+
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
@@ -769,7 +788,7 @@ def main():
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                    if args.use_class_images:
+                    if args.num_class_images != 0:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                         noise, noise_prior = torch.chunk(noise, 2, dim=0)
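
Only the condition changes in this hunk, but the chunking is the core of the Dreambooth-style prior-preservation loss: predictions for instance and class samples are split apart and penalized separately. A self-contained sketch of the loss shape; the weighted combination at the end follows the usual Dreambooth recipe and is an assumption, since this hunk does not show how the two terms are merged:

    import torch
    import torch.nn.functional as F

    noise_pred = torch.randn(4, 4, 8, 8)  # 2 instance + 2 class predictions
    noise = torch.randn(4, 4, 8, 8)

    # First half: instance samples; second half: class ("prior") samples.
    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(noise, 2, dim=0)

    loss = F.mse_loss(noise_pred, noise, reduction="mean")
    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="mean")

    prior_loss_weight = 1.0  # assumed hyperparameter, not in this diff
    loss = loss + prior_loss_weight * prior_loss
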
@@ -812,6 +831,9 @@ def main():
 
                 global_step += 1
 
+                if global_step % args.sample_frequency == 0:
+                    sample_checkpoint = True
+
                 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
@@ -878,7 +900,7 @@ def main():
                     checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                     min_val_loss = val_loss
 
-            if accelerator.is_main_process:
+            if sample_checkpoint and accelerator.is_main_process:
                 checkpointer.save_samples(
                     global_step + global_step_offset,
                     text_encoder,
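
Taken together, the sample_checkpoint hunks implement a per-epoch latch: crossing any sample_frequency boundary during the epoch arms the flag, and samples are written once when the epoch finishes. A distilled, runnable sketch of the pattern (step counts invented):

    sample_frequency = 100
    steps_per_epoch = 250
    global_step = 0

    for epoch in range(3):
        sample_checkpoint = False  # reset at the top of each epoch

        for _ in range(steps_per_epoch):
            global_step += 1
            if global_step % sample_frequency == 0:
                sample_checkpoint = True  # latch until the epoch ends

        if sample_checkpoint:
            print(f"epoch {epoch}: would save samples at step {global_step}")
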