author     Volpeon <git@volpeon.ink>  2023-01-08 09:43:22 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-08 09:43:22 +0100
commit     5571c4ebcb39813e2bd8585de30c64bb02f9d7fa (patch)
tree       a073f625eaa49c3cd908aacb3debae23e5badbf7 /train_dreambooth.py
parent     Cleanup (diff)
Improved aspect ratio bucketing
Diffstat (limited to 'train_dreambooth.py')

 -rw-r--r--  train_dreambooth.py | 100
 1 file changed, 45 insertions(+), 55 deletions(-)
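Aspect ratio bucketing groups training images into a fixed set of resolutions with a roughly constant pixel budget and assigns each image the bucket whose aspect ratio is closest to its own, so batches never mix resolutions and wide or tall images escape aggressive square cropping. The datamodule that implements it is not part of this patch (only its call sites in train_dreambooth.py change), so the following is a minimal background sketch of the general technique; make_buckets and assign_bucket are illustrative names, not code from this repository.

    # Background sketch of aspect ratio bucketing; names are illustrative,
    # not taken from this repository's datamodule.
    def make_buckets(base: int = 512, step: int = 64, max_ratio: float = 2.0):
        """Enumerate (width, height) buckets with roughly base*base pixels each."""
        max_area = base * base
        buckets = {(base, base)}
        w = base + step
        while True:
            h = max_area // w // step * step  # tallest height within the pixel budget
            if h == 0 or w / h > max_ratio:
                break
            buckets.add((w, h))
            buckets.add((h, w))  # mirrored bucket for portrait images
            w += step
        return sorted(buckets)

    def assign_bucket(width: int, height: int, buckets):
        """Pick the bucket whose aspect ratio best matches the image."""
        ratio = width / height
        return min(buckets, key=lambda wh: abs(wh[0] / wh[1] - ratio))

    # Example: a 1920x1080 photo lands in a landscape bucket instead of
    # being center-cropped to a square.
    print(assign_bucket(1920, 1080, make_buckets()))  # -> (640, 384)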
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 589af59..42a7d0f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -134,12 +134,6 @@ def parse_args():
         help="The directory where class images will be saved.",
     )
     parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
         default="output/dreambooth",
@@ -738,7 +732,6 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -751,7 +744,7 @@
     datamodule.prepare_data()
     datamodule.setup()

-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader

     if args.num_class_images != 0:
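The datamodule now exposes a single train_dataloader instead of the per-bucket list train_dataloaders used previously. The data code is outside this diff, but one common way to serve bucketed data through one dataloader is a batch sampler that only ever groups indices from the same bucket; the sketch below assumes that approach and is not this repository's data.py.

    # Sketch of a bucket-aware batch sampler (an assumption about the
    # approach, not this repository's code). Every emitted batch draws
    # indices from a single bucket, so each batch has one resolution.
    import random
    from torch.utils.data import Sampler

    class BucketBatchSampler(Sampler):
        def __init__(self, bucket_to_indices, batch_size, shuffle=True):
            self.bucket_to_indices = bucket_to_indices  # bucket -> dataset indices
            self.batch_size = batch_size
            self.shuffle = shuffle

        def __iter__(self):
            batches = []
            for indices in self.bucket_to_indices.values():
                indices = list(indices)
                if self.shuffle:
                    random.shuffle(indices)
                # Chunk each bucket on its own so batches never mix buckets.
                for i in range(0, len(indices), self.batch_size):
                    batches.append(indices[i:i + self.batch_size])
            if self.shuffle:
                random.shuffle(batches)  # interleave buckets across the epoch
            return iter(batches)

        def __len__(self):
            return sum(
                -(-len(ix) // self.batch_size)  # ceiling division per bucket
                for ix in self.bucket_to_indices.values()
            )

    # Usage: DataLoader(dataset, batch_sampler=BucketBatchSampler(buckets, 4))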
@@ -770,8 +763,7 @@

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
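With one dataloader, the update-step arithmetic collapses from a sum over dataloaders to a single expression. A worked example with illustrative numbers (not values from any particular run):

    import math

    # Illustrative numbers only.
    batches_per_epoch = 100              # len(train_dataloader)
    gradient_accumulation_steps = 4
    num_train_epochs = 10

    num_update_steps_per_epoch = math.ceil(
        batches_per_epoch / gradient_accumulation_steps
    )
    assert num_update_steps_per_epoch == 25

    # When --max_train_steps is unset, it defaults to epochs * updates/epoch.
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
    assert max_train_steps == 250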
@@ -820,8 +812,7 @@
         ema_unet.to(accelerator.device)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

@@ -877,7 +868,7 @@
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=tokenizer.train,
@@ -960,54 +951,53 @@
         text_encoder.requires_grad_(False)

         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(unet):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        if accelerator.sync_gradients:
-                            params_to_clip = (
-                                itertools.chain(unet.parameters(), text_encoder.parameters())
-                                if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                                else unet.parameters()
-                            )
-                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        if args.use_ema:
-                            ema_unet.step(unet.parameters())
-                        optimizer.zero_grad(set_to_none=True)
-
-                    avg_loss.update(loss.detach_(), bsz)
-                    avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        global_step += 1
-
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0]
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = 1 - ema_unet.decay
-
-                        accelerator.log(logs, step=global_step)
-
-                        local_progress_bar.set_postfix(**logs)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(unet):
+                    loss, acc, bsz = loop(step, batch)
+
+                    accelerator.backward(loss)
+
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            itertools.chain(unet.parameters(), text_encoder.parameters())
+                            if args.train_text_encoder and epoch < args.train_text_encoder_epochs
+                            else unet.parameters()
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    if args.use_ema:
+                        ema_unet.step(unet.parameters())
+                    optimizer.zero_grad(set_to_none=True)
+
+                avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0]
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = 1 - ema_unet.decay
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)

             if global_step >= args.max_train_steps:
                 break

         accelerator.wait_for_everyone()

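The loop body is otherwise unchanged; only the per-dataloader level of nesting is gone. Its structure leans on the accelerate accumulation idiom: accelerator.sync_gradients is True only on batches where an optimizer step actually happens, which is why global_step and the progress bars advance only then. A self-contained toy demonstration of that idiom (generic accelerate usage, not this repository's loop callable):

    import torch
    from accelerate import Accelerator
    from torch.utils.data import DataLoader, TensorDataset

    # Toy setup: 12 batches, accumulating gradients over 4 batches per update.
    accelerator = Accelerator(gradient_accumulation_steps=4)
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    dataset = TensorDataset(torch.randn(24, 8), torch.randn(24, 1))
    loader = DataLoader(dataset, batch_size=2)
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    global_step = 0
    for step, (x, y) in enumerate(loader):
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)  # gradients only sync on boundaries
            optimizer.step()            # skipped internally between boundaries
            optimizer.zero_grad()
        if accelerator.sync_gradients:  # True once per 4 batches
            global_step += 1

    assert global_step == 3  # 12 batches / 4 accumulation steps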