author     Volpeon <git@volpeon.ink>  2023-01-08 09:43:22 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-08 09:43:22 +0100
commit     5571c4ebcb39813e2bd8585de30c64bb02f9d7fa (patch)
tree       a073f625eaa49c3cd908aacb3debae23e5badbf7 /train_dreambooth.py
parent     Cleanup (diff)
Improved aspect ratio bucketing
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--   train_dreambooth.py   92
1 file changed, 41 insertions, 51 deletions
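
Aspect ratio bucketing, which this commit refines, groups the training images into a small set of resolutions with similar aspect ratios so that every batch can be resized to one shared shape instead of center-cropping everything to a square. A minimal sketch of that grouping step follows; the bucket list and the helper names (nearest_bucket, assign_buckets) are illustrative assumptions, not code from train_dreambooth.py.

# Sketch of the general bucketing idea; bucket sizes and names are
# illustrative assumptions, not taken from this repository.
from typing import Dict, List, Tuple

BUCKETS: List[Tuple[int, int]] = [
    (512, 512), (448, 576), (576, 448), (384, 640), (640, 384),
]


def nearest_bucket(width: int, height: int) -> Tuple[int, int]:
    """Return the bucket whose aspect ratio is closest to the image's."""
    ratio = width / height
    return min(BUCKETS, key=lambda b: abs(b[0] / b[1] - ratio))


def assign_buckets(sizes: List[Tuple[int, int]]) -> Dict[Tuple[int, int], List[int]]:
    """Group image indices by their nearest bucket resolution."""
    groups: Dict[Tuple[int, int], List[int]] = {}
    for idx, (w, h) in enumerate(sizes):
        groups.setdefault(nearest_bucket(w, h), []).append(idx)
    return groups
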
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 589af59..42a7d0f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -134,12 +134,6 @@ def parse_args():
         help="The directory where class images will be saved.",
     )
     parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
         default="output/dreambooth",
@@ -738,7 +732,6 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -751,7 +744,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader
 
     if args.num_class_images != 0:
@@ -770,8 +763,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
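
As a worked example of the scheduler math in the hunk above (with made-up numbers, not taken from any actual run): a single bucketed dataloader that yields 1000 batches per epoch with gradient_accumulation_steps=4 gives ceil(1000 / 4) = 250 optimizer updates per epoch, and with num_train_epochs=100 the derived max_train_steps is 25000.

import math

# Hypothetical values, for illustration only.
batches_per_epoch = 1000            # stands in for len(train_dataloader)
gradient_accumulation_steps = 4
num_train_epochs = 100

num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch
print(num_update_steps_per_epoch, max_train_steps)  # 250 25000

The same recalculation appears again in the next hunk because, as its comment notes, the size of the training dataloader may have changed by that point.
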
@@ -820,8 +812,7 @@ def main():
         ema_unet.to(accelerator.device)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -877,7 +868,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=tokenizer.train,
@@ -960,54 +951,53 @@ def main():
             text_encoder.requires_grad_(False)
 
         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(unet):
-                        loss, acc, bsz = loop(step, batch)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(unet):
+                    loss, acc, bsz = loop(step, batch)
 
-                        accelerator.backward(loss)
+                    accelerator.backward(loss)
 
-                        if accelerator.sync_gradients:
-                            params_to_clip = (
-                                itertools.chain(unet.parameters(), text_encoder.parameters())
-                                if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                                else unet.parameters()
-                            )
-                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            itertools.chain(unet.parameters(), text_encoder.parameters())
+                            if args.train_text_encoder and epoch < args.train_text_encoder_epochs
+                            else unet.parameters()
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        if args.use_ema:
-                            ema_unet.step(unet.parameters())
-                        optimizer.zero_grad(set_to_none=True)
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    if args.use_ema:
+                        ema_unet.step(unet.parameters())
+                    optimizer.zero_grad(set_to_none=True)
 
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-                        global_step += 1
+                    global_step += 1
 
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0]
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = 1 - ema_unet.decay
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0]
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = 1 - ema_unet.decay
 
-                        accelerator.log(logs, step=global_step)
+                    accelerator.log(logs, step=global_step)
 
-                        local_progress_bar.set_postfix(**logs)
+                    local_progress_bar.set_postfix(**logs)
 
-                    if global_step >= args.max_train_steps:
-                        break
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
 
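
The training loop now reads batches from a single train_dataloader even though the dataset spans several bucket resolutions. One common way to get that behaviour from one DataLoader, sketched below under the assumption of a custom batch sampler (the class name BucketBatchSampler and all of its details are hypothetical, not the datamodule's actual implementation), is to only ever group indices that fall in the same bucket:

import math
import random
from typing import Dict, Iterator, List, Tuple

from torch.utils.data import Sampler


class BucketBatchSampler(Sampler):
    """Yield index batches whose members all share one bucket resolution.

    Hypothetical sketch: the repository's datamodule may implement the
    single-dataloader behaviour differently.
    """

    def __init__(self, buckets: Dict[Tuple[int, int], List[int]], batch_size: int):
        self.buckets = buckets
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[List[int]]:
        batches: List[List[int]] = []
        for indices in self.buckets.values():
            shuffled = list(indices)
            random.shuffle(shuffled)
            # Split each bucket into batches; no batch mixes resolutions.
            batches += [shuffled[i:i + self.batch_size]
                        for i in range(0, len(shuffled), self.batch_size)]
        random.shuffle(batches)  # interleave buckets across the epoch
        yield from batches

    def __len__(self) -> int:
        return sum(math.ceil(len(v) / self.batch_size) for v in self.buckets.values())

A DataLoader constructed with batch_sampler=BucketBatchSampler(...) then yields batches whose images share one resolution, and len(train_dataloader) counts batches across every bucket, which is what the simplified num_update_steps_per_epoch calculation above relies on.
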