diff options
author | Volpeon <git@volpeon.ink> | 2023-01-07 13:57:46 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-07 13:57:46 +0100 |
commit | 3ee13893f9a4973ac75f45fe9318c35760dd4b1f (patch) | |
tree | e652a54e6c241eef52ddb30f2d7048da8f306f7b /train_dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.gz textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.bz2 textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.zip |
Added progressive aspect ratio bucketing
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 12 |
1 files changed, 3 insertions, 9 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index e8256be..d265bcc 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -22,7 +22,7 @@ from slugify import slugify | |||
22 | 22 | ||
23 | from util import load_config, load_embeddings_from_dir | 23 | from util import load_config, load_embeddings_from_dir |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import VlpnDataModule, VlpnDataItem |
26 | from training.common import run_model | 26 | from training.common import run_model |
27 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
28 | from training.lr import LRFinder | 28 | from training.lr import LRFinder |
@@ -172,11 +172,6 @@ def parse_args(): | |||
172 | ), | 172 | ), |
173 | ) | 173 | ) |
174 | parser.add_argument( | 174 | parser.add_argument( |
175 | "--center_crop", | ||
176 | action="store_true", | ||
177 | help="Whether to center crop images before resizing to resolution" | ||
178 | ) | ||
179 | parser.add_argument( | ||
180 | "--dataloader_num_workers", | 175 | "--dataloader_num_workers", |
181 | type=int, | 176 | type=int, |
182 | default=0, | 177 | default=0, |
@@ -698,7 +693,7 @@ def main(): | |||
698 | elif args.mixed_precision == "bf16": | 693 | elif args.mixed_precision == "bf16": |
699 | weight_dtype = torch.bfloat16 | 694 | weight_dtype = torch.bfloat16 |
700 | 695 | ||
701 | def keyword_filter(item: CSVDataItem): | 696 | def keyword_filter(item: VlpnDataItem): |
702 | cond3 = args.collection is None or args.collection in item.collection | 697 | cond3 = args.collection is None or args.collection in item.collection |
703 | cond4 = args.exclude_collections is None or not any( | 698 | cond4 = args.exclude_collections is None or not any( |
704 | collection in item.collection | 699 | collection in item.collection |
@@ -733,7 +728,7 @@ def main(): | |||
733 | } | 728 | } |
734 | return batch | 729 | return batch |
735 | 730 | ||
736 | datamodule = CSVDataModule( | 731 | datamodule = VlpnDataModule( |
737 | data_file=args.train_data_file, | 732 | data_file=args.train_data_file, |
738 | batch_size=args.train_batch_size, | 733 | batch_size=args.train_batch_size, |
739 | prompt_processor=prompt_processor, | 734 | prompt_processor=prompt_processor, |
@@ -742,7 +737,6 @@ def main(): | |||
742 | size=args.resolution, | 737 | size=args.resolution, |
743 | repeats=args.repeats, | 738 | repeats=args.repeats, |
744 | dropout=args.tag_dropout, | 739 | dropout=args.tag_dropout, |
745 | center_crop=args.center_crop, | ||
746 | template_key=args.train_data_template, | 740 | template_key=args.train_data_template, |
747 | valid_set_size=args.valid_set_size, | 741 | valid_set_size=args.valid_set_size, |
748 | num_workers=args.dataloader_num_workers, | 742 | num_workers=args.dataloader_num_workers, |