diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 48 |
1 files changed, 35 insertions, 13 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 11c324d..86fcdfe 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -68,16 +68,17 @@ def parse_args(): | |||
68 | help="A token to use as initializer word." | 68 | help="A token to use as initializer word." |
69 | ) | 69 | ) |
70 | parser.add_argument( | 70 | parser.add_argument( |
71 | "--use_class_images", | 71 | "--num_class_images", |
72 | action="store_true", | 72 | type=int, |
73 | default=True, | 73 | default=2, |
74 | help="Include class images in the loss calculation a la Dreambooth.", | 74 | help="How many class images to generate per training image." |
75 | ) | 75 | ) |
76 | parser.add_argument( | 76 | parser.add_argument( |
77 | "--repeats", | 77 | "--repeats", |
78 | type=int, | 78 | type=int, |
79 | default=100, | 79 | default=100, |
80 | help="How many times to repeat the training data.") | 80 | help="How many times to repeat the training data." |
81 | ) | ||
81 | parser.add_argument( | 82 | parser.add_argument( |
82 | "--output_dir", | 83 | "--output_dir", |
83 | type=str, | 84 | type=str, |
@@ -204,6 +205,12 @@ def parse_args(): | |||
204 | help="How often to save a checkpoint and sample image", | 205 | help="How often to save a checkpoint and sample image", |
205 | ) | 206 | ) |
206 | parser.add_argument( | 207 | parser.add_argument( |
208 | "--sample_frequency", | ||
209 | type=int, | ||
210 | default=100, | ||
211 | help="How often to save a checkpoint and sample image", | ||
212 | ) | ||
213 | parser.add_argument( | ||
207 | "--sample_image_size", | 214 | "--sample_image_size", |
208 | type=int, | 215 | type=int, |
209 | default=512, | 216 | default=512, |
@@ -381,6 +388,7 @@ class Checkpointer: | |||
381 | scheduler=scheduler, | 388 | scheduler=scheduler, |
382 | ).to(self.accelerator.device) | 389 | ).to(self.accelerator.device) |
383 | pipeline.enable_attention_slicing() | 390 | pipeline.enable_attention_slicing() |
391 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
384 | 392 | ||
385 | train_data = self.datamodule.train_dataloader() | 393 | train_data = self.datamodule.train_dataloader() |
386 | val_data = self.datamodule.val_dataloader() | 394 | val_data = self.datamodule.val_dataloader() |
@@ -577,7 +585,7 @@ def main(): | |||
577 | pixel_values = [example["instance_images"] for example in examples] | 585 | pixel_values = [example["instance_images"] for example in examples] |
578 | 586 | ||
579 | # concat class and instance examples for prior preservation | 587 | # concat class and instance examples for prior preservation |
580 | if args.use_class_images and "class_prompt_ids" in examples[0]: | 588 | if args.num_class_images != 0 and "class_prompt_ids" in examples[0]: |
581 | input_ids += [example["class_prompt_ids"] for example in examples] | 589 | input_ids += [example["class_prompt_ids"] for example in examples] |
582 | pixel_values += [example["class_images"] for example in examples] | 590 | pixel_values += [example["class_images"] for example in examples] |
583 | 591 | ||
@@ -599,8 +607,9 @@ def main(): | |||
599 | batch_size=args.train_batch_size, | 607 | batch_size=args.train_batch_size, |
600 | tokenizer=tokenizer, | 608 | tokenizer=tokenizer, |
601 | instance_identifier=args.placeholder_token, | 609 | instance_identifier=args.placeholder_token, |
602 | class_identifier=args.initializer_token if args.use_class_images else None, | 610 | class_identifier=args.initializer_token, |
603 | class_subdir="ti_cls", | 611 | class_subdir="ti_cls", |
612 | num_class_images=args.num_class_images, | ||
604 | size=args.resolution, | 613 | size=args.resolution, |
605 | repeats=args.repeats, | 614 | repeats=args.repeats, |
606 | center_crop=args.center_crop, | 615 | center_crop=args.center_crop, |
@@ -611,7 +620,7 @@ def main(): | |||
611 | datamodule.prepare_data() | 620 | datamodule.prepare_data() |
612 | datamodule.setup() | 621 | datamodule.setup() |
613 | 622 | ||
614 | if args.use_class_images: | 623 | if args.num_class_images != 0: |
615 | missing_data = [item for item in datamodule.data if not item[1].exists()] | 624 | missing_data = [item for item in datamodule.data if not item[1].exists()] |
616 | 625 | ||
617 | if len(missing_data) != 0: | 626 | if len(missing_data) != 0: |
@@ -630,6 +639,7 @@ def main(): | |||
630 | scheduler=scheduler, | 639 | scheduler=scheduler, |
631 | ).to(accelerator.device) | 640 | ).to(accelerator.device) |
632 | pipeline.enable_attention_slicing() | 641 | pipeline.enable_attention_slicing() |
642 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
633 | 643 | ||
634 | for batch in batched_data: | 644 | for batch in batched_data: |
635 | image_name = [p[1] for p in batch] | 645 | image_name = [p[1] for p in batch] |
@@ -729,11 +739,18 @@ def main(): | |||
729 | text_encoder, | 739 | text_encoder, |
730 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 740 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
731 | 741 | ||
732 | local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 742 | local_progress_bar = tqdm( |
733 | disable=not accelerator.is_local_main_process) | 743 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
744 | disable=not accelerator.is_local_main_process, | ||
745 | dynamic_ncols=True | ||
746 | ) | ||
734 | local_progress_bar.set_description("Batch X out of Y") | 747 | local_progress_bar.set_description("Batch X out of Y") |
735 | 748 | ||
736 | global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process) | 749 | global_progress_bar = tqdm( |
750 | range(args.max_train_steps + val_steps), | ||
751 | disable=not accelerator.is_local_main_process, | ||
752 | dynamic_ncols=True | ||
753 | ) | ||
737 | global_progress_bar.set_description("Total progress") | 754 | global_progress_bar.set_description("Total progress") |
738 | 755 | ||
739 | try: | 756 | try: |
@@ -744,6 +761,8 @@ def main(): | |||
744 | text_encoder.train() | 761 | text_encoder.train() |
745 | train_loss = 0.0 | 762 | train_loss = 0.0 |
746 | 763 | ||
764 | sample_checkpoint = False | ||
765 | |||
747 | for step, batch in enumerate(train_dataloader): | 766 | for step, batch in enumerate(train_dataloader): |
748 | with accelerator.accumulate(text_encoder): | 767 | with accelerator.accumulate(text_encoder): |
749 | # Convert images to latent space | 768 | # Convert images to latent space |
@@ -769,7 +788,7 @@ def main(): | |||
769 | # Predict the noise residual | 788 | # Predict the noise residual |
770 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 789 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
771 | 790 | ||
772 | if args.use_class_images: | 791 | if args.num_class_images != 0: |
773 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 792 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
774 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 793 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
775 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 794 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
@@ -812,6 +831,9 @@ def main(): | |||
812 | 831 | ||
813 | global_step += 1 | 832 | global_step += 1 |
814 | 833 | ||
834 | if global_step % args.sample_frequency == 0: | ||
835 | sample_checkpoint = True | ||
836 | |||
815 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | 837 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: |
816 | local_progress_bar.clear() | 838 | local_progress_bar.clear() |
817 | global_progress_bar.clear() | 839 | global_progress_bar.clear() |
@@ -878,7 +900,7 @@ def main(): | |||
878 | checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) | 900 | checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) |
879 | min_val_loss = val_loss | 901 | min_val_loss = val_loss |
880 | 902 | ||
881 | if accelerator.is_main_process: | 903 | if sample_checkpoint and accelerator.is_main_process: |
882 | checkpointer.save_samples( | 904 | checkpointer.save_samples( |
883 | global_step + global_step_offset, | 905 | global_step + global_step_offset, |
884 | text_encoder, | 906 | text_encoder, |