author    | Volpeon <git@volpeon.ink> | 2023-01-08 09:43:22 +0100
committer | Volpeon <git@volpeon.ink> | 2023-01-08 09:43:22 +0100
commit    | 5571c4ebcb39813e2bd8585de30c64bb02f9d7fa
tree      | a073f625eaa49c3cd908aacb3debae23e5badbf7 /train_ti.py
parent    | Cleanup
Improved aspect ratio bucketing
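The change builds on the aspect ratio bucketing the datamodule already exposes through `num_aspect_ratio_buckets` and `progressive_aspect_ratio_buckets` (both visible as context lines in the diff below): instead of scaling every image to one square resolution, images are grouped into buckets of similar aspect ratio, and each batch is drawn from a single bucket. As a rough illustration only (a minimal sketch with made-up helper names and sizes, not this repository's implementation):

```python
# Illustrative sketch of aspect ratio bucketing; not the datamodule's actual code.
# Each image is assigned to the bucket whose aspect ratio is closest to its own
# (compared in log space) and is later resized to that bucket's resolution, so
# every image in a batch shares the same shape.
import math

def make_buckets(base_size=768, num_buckets=4, step=64):
    """Generate (width, height) buckets around a square base resolution."""
    buckets = [(base_size, base_size)]
    for i in range(1, num_buckets + 1):
        buckets.append((base_size + i * step, base_size - i * step))  # landscape
        buckets.append((base_size - i * step, base_size + i * step))  # portrait
    return buckets

def assign_bucket(width, height, buckets):
    """Pick the bucket whose aspect ratio best matches the image."""
    log_ratio = math.log(width / height)
    return min(buckets, key=lambda b: abs(math.log(b[0] / b[1]) - log_ratio))

print(assign_bucket(1920, 1080, make_buckets()))  # -> (960, 576) for a 16:9 image
```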
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 85 lines changed
1 file changed, 37 insertions, 48 deletions
diff --git a/train_ti.py b/train_ti.py
index b4b602b..727b591 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -107,12 +107,6 @@ def parse_args():
         help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
         default="output/text-inversion",
@@ -722,7 +716,6 @@ def main():
         size=args.resolution,
         num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
         progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -733,7 +726,7 @@ def main():
     )
     datamodule.setup()

-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader

     if args.num_class_images != 0:
@@ -752,8 +745,7 @@ def main():

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -790,10 +782,9 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )

-    text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, val_dataloader, lr_scheduler
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
-    train_dataloaders = accelerator.prepare(*train_dataloaders)

     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
@@ -811,8 +802,7 @@ def main():
     unet.eval()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

@@ -870,7 +860,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=on_train,
@@ -949,48 +939,47 @@ def main():
         text_encoder.train()

         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(text_encoder):
-                        loss, acc, bsz = loop(step, batch)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(text_encoder):
+                    loss, acc, bsz = loop(step, batch)

-                        accelerator.backward(loss)
+                    accelerator.backward(loss)

-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        optimizer.zero_grad(set_to_none=True)
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)

-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)

-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        if args.use_ema:
-                            ema_embeddings.step(
-                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    if args.use_ema:
+                        ema_embeddings.step(
+                            text_encoder.text_model.embeddings.temp_token_embedding.parameters())

-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)

-                        global_step += 1
+                    global_step += 1

-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0],
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = ema_embeddings.decay
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = ema_embeddings.decay

-                        accelerator.log(logs, step=global_step)
+                    accelerator.log(logs, step=global_step)

-                        local_progress_bar.set_postfix(**logs)
+                    local_progress_bar.set_postfix(**logs)

-                    if global_step >= args.max_train_steps:
-                        break
+                if global_step >= args.max_train_steps:
+                    break

         accelerator.wait_for_everyone()

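The bulk of the diff replaces the old list of per-bucket dataloaders (`train_dataloaders`) with a single `train_dataloader`, which simplifies `accelerator.prepare`, the update-step math, and the epoch loop. One common way to serve bucketed data through a single loader is a batch sampler that draws every batch from one bucket; the sketch below assumes PyTorch and uses illustrative names (`BucketBatchSampler`, `bucket_indices`), and is not necessarily how this datamodule implements it.

```python
# Hypothetical sketch: one DataLoader over all aspect ratio buckets, where every
# batch is drawn from a single bucket so all images in it share one resolution.
import random
from torch.utils.data import DataLoader, Sampler

class BucketBatchSampler(Sampler):
    """Yields batches of dataset indices, each batch taken from a single bucket."""

    def __init__(self, bucket_indices, batch_size, shuffle=True):
        self.bucket_indices = bucket_indices  # dict: bucket id -> list of dataset indices
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        batches = []
        for indices in self.bucket_indices.values():
            indices = list(indices)
            if self.shuffle:
                random.shuffle(indices)
            # Split each bucket into batches of uniform image size
            batches += [indices[i:i + self.batch_size]
                        for i in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            random.shuffle(batches)  # interleave buckets across the epoch
        return iter(batches)

    def __len__(self):
        return sum(
            (len(v) + self.batch_size - 1) // self.batch_size
            for v in self.bucket_indices.values()
        )

# train_dataloader = DataLoader(train_dataset,
#                               batch_sampler=BucketBatchSampler(bucket_indices, 4))
```

With a single loader, `len(train_dataloader)` is simply the number of batches per epoch, which is why the diff can reduce the step math to `math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)`; for example, 100 batches with `gradient_accumulation_steps = 4` give 25 update steps per epoch.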