Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 85
1 file changed, 37 insertions, 48 deletions
diff --git a/train_ti.py b/train_ti.py
index b4b602b..727b591 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -107,12 +107,6 @@ def parse_args():
         help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
         default="output/text-inversion",
@@ -722,7 +716,6 @@ def main():
         size=args.resolution,
         num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
         progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -733,7 +726,7 @@ def main():
     )
     datamodule.setup()
 
-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader
 
     if args.num_class_images != 0:
@@ -752,8 +745,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -790,10 +782,9 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, val_dataloader, lr_scheduler
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
-    train_dataloaders = accelerator.prepare(*train_dataloaders)
 
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
@@ -811,8 +802,7 @@ def main():
     unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -870,7 +860,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=on_train,
@@ -949,48 +939,47 @@ def main():
         text_encoder.train()
 
         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(text_encoder):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        if args.use_ema:
-                            ema_embeddings.step(
-                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        global_step += 1
-
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0],
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = ema_embeddings.decay
-
-                        accelerator.log(logs, step=global_step)
-
-                        local_progress_bar.set_postfix(**logs)
-
-                    if global_step >= args.max_train_steps:
-                        break
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(text_encoder):
+                    loss, acc, bsz = loop(step, batch)
+
+                    accelerator.backward(loss)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    if args.use_ema:
+                        ema_embeddings.step(
+                            text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = ema_embeddings.decay
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)
+
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
 
