Diffstat (limited to 'textual_inversion.py')

-rw-r--r--  textual_inversion.py | 48 +++++++++++++++++++++++++++++++++++-------------
1 file changed, 35 insertions(+), 13 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 11c324d..86fcdfe 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -68,16 +68,17 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--use_class_images",
-        action="store_true",
-        default=True,
-        help="Include class images in the loss calculation a la Dreambooth.",
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
     )
     parser.add_argument(
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
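
This hunk replaces the boolean `--use_class_images` switch with an integer `--num_class_images`, so one option both enables prior preservation and controls how many class images are generated per training image; passing 0 disables the feature entirely. A minimal sketch of the new behavior (the surrounding parser setup is assumed):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--num_class_images",
    type=int,
    default=2,
    help="How many class images to generate per training image."
)

args = parser.parse_args(["--num_class_images", "0"])

# A count of 0 now plays the role of the old --use_class_images=False:
use_prior_preservation = args.num_class_images != 0
print(use_prior_preservation)  # False
```
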
@@ -204,6 +205,12 @@ def parse_args():
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=100,
+        help="How often to save a sample image",
+    )
+    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=512,
@@ -381,6 +388,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
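
`set_progress_bar_config` is a standard `DiffusionPipeline` method: it stores its keyword arguments and forwards them to the `tqdm` bar used during the denoising loop, so `dynamic_ncols=True` makes the bar re-measure the terminal width on every refresh. A short sketch (the checkpoint id is illustrative, not from this repo):

```python
from diffusers import StableDiffusionPipeline

# The model id is only an example; any pipeline checkpoint works the same way.
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipeline.enable_attention_slicing()                   # as in the hunk above
pipeline.set_progress_bar_config(dynamic_ncols=True)  # kwargs are passed to tqdm
```
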
@@ -577,7 +585,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]
 
         # concat class and instance examples for prior preservation
-        if args.use_class_images and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
 
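
The collate function builds each batch instance-first, class-second, which is the invariant that later lets the training loop split predictions in half with `torch.chunk`. A sketch of that layout; the `instance_prompt_ids` key is assumed from context, since the hunk only shows the class half being appended:

```python
import torch

def collate_fn(examples, num_class_images):
    # Instance examples come first...
    input_ids = [e["instance_prompt_ids"] for e in examples]  # assumed key
    pixel_values = [e["instance_images"] for e in examples]

    # ...and class examples are appended behind them, so rows
    # [0, len(examples)) are instance data and the rest are class data.
    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
        input_ids += [e["class_prompt_ids"] for e in examples]
        pixel_values += [e["class_images"] for e in examples]

    return input_ids, torch.stack(pixel_values)
```
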
@@ -599,8 +607,9 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token if args.use_class_images else None,
+        class_identifier=args.initializer_token,
         class_subdir="ti_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
@@ -611,7 +620,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    if args.use_class_images:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]
 
         if len(missing_data) != 0:
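
With generation keyed off the image count, the gate around the missing-file scan changes accordingly. Assuming `datamodule.data` holds `(prompt, path)`-style pairs, as the comprehension suggests, the check amounts to:

```python
from pathlib import Path

# Hypothetical stand-in for datamodule.data: (prompt, target_path) pairs.
data = [
    ("a photo of a person", Path("ti_cls/00001.jpg")),
    ("a photo of a person", Path("ti_cls/00002.jpg")),
]

missing_data = [item for item in data if not item[1].exists()]
if len(missing_data) != 0:
    print(f"need to generate {len(missing_data)} class images")
```
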
@@ -630,6 +639,7 @@ def main():
             scheduler=scheduler,
         ).to(accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         for batch in batched_data:
             image_name = [p[1] for p in batch]
@@ -729,11 +739,18 @@ def main():
         text_encoder,
         args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")
 
-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")
 
     try:
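
The two training progress bars get the same `dynamic_ncols=True` treatment as the pipelines. Without it, `tqdm` fixes its width when the bar is created and garbles output if the terminal is resized later. A standalone illustration:

```python
from time import sleep
from tqdm import tqdm

# dynamic_ncols=True re-queries the terminal width on every refresh,
# so the bar survives window resizes instead of wrapping onto new lines.
for _ in tqdm(range(100), desc="Total progress", dynamic_ncols=True):
    sleep(0.01)
```
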
@@ -744,6 +761,8 @@ def main():
             text_encoder.train()
             train_loss = 0.0
 
+            sample_checkpoint = False
+
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
@@ -769,7 +788,7 @@ def main():
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                    if args.use_class_images:
+                    if args.num_class_images != 0:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                         noise, noise_prior = torch.chunk(noise, 2, dim=0)
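
Because the batch was assembled instance-first, class-second (see the collate hunk above), chunking along the batch dimension recovers the two halves, and each gets its own MSE term a la Dreambooth prior preservation. A hedged sketch of how the two terms are typically combined; the diff itself stops at the chunking, and `prior_loss_weight` is an assumed name:

```python
import torch
import torch.nn.functional as F

def chunked_loss(noise_pred, noise, prior_loss_weight=1.0):
    # Split predictions and targets back into instance and class halves.
    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(noise, 2, dim=0)

    # One MSE term per half; the weighting is an assumption, not shown in the diff.
    loss = F.mse_loss(noise_pred, noise)
    prior_loss = F.mse_loss(noise_pred_prior, noise_prior)
    return loss + prior_loss_weight * prior_loss
```
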
@@ -812,6 +831,9 @@ def main():
 
                     global_step += 1
 
+                    if global_step % args.sample_frequency == 0:
+                        sample_checkpoint = True
+
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
@@ -878,7 +900,7 @@ def main():
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                 min_val_loss = val_loss
 
-            if accelerator.is_main_process:
+            if sample_checkpoint and accelerator.is_main_process:
                 checkpointer.save_samples(
                     global_step + global_step_offset,
                     text_encoder,
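
Putting the last three hunks together: `sample_checkpoint` is reset at the top of each epoch, raised whenever `global_step` crosses a multiple of `--sample_frequency`, and consumed once at the end of the epoch, so samples are no longer saved unconditionally every epoch. A condensed, self-contained sketch of that cadence:

```python
sample_frequency = 100  # --sample_frequency
global_step = 0

for epoch in range(5):
    sample_checkpoint = False

    for step in range(240):  # stand-in for the train_dataloader loop
        global_step += 1
        if global_step % sample_frequency == 0:
            sample_checkpoint = True

    # In the real loop this is also gated on accelerator.is_main_process.
    if sample_checkpoint:
        print(f"epoch {epoch}: save samples at step {global_step}")
```
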
