diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-07 13:57:46 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-07 13:57:46 +0100 |
| commit | 3ee13893f9a4973ac75f45fe9318c35760dd4b1f (patch) | |
| tree | e652a54e6c241eef52ddb30f2d7048da8f306f7b /train_dreambooth.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.gz textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.bz2 textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.zip | |
Added progressive aspect ratio bucketing
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 12 |
1 files changed, 3 insertions, 9 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index e8256be..d265bcc 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -22,7 +22,7 @@ from slugify import slugify | |||
| 22 | 22 | ||
| 23 | from util import load_config, load_embeddings_from_dir | 23 | from util import load_config, load_embeddings_from_dir |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import VlpnDataModule, VlpnDataItem |
| 26 | from training.common import run_model | 26 | from training.common import run_model |
| 27 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
| 28 | from training.lr import LRFinder | 28 | from training.lr import LRFinder |
| @@ -172,11 +172,6 @@ def parse_args(): | |||
| 172 | ), | 172 | ), |
| 173 | ) | 173 | ) |
| 174 | parser.add_argument( | 174 | parser.add_argument( |
| 175 | "--center_crop", | ||
| 176 | action="store_true", | ||
| 177 | help="Whether to center crop images before resizing to resolution" | ||
| 178 | ) | ||
| 179 | parser.add_argument( | ||
| 180 | "--dataloader_num_workers", | 175 | "--dataloader_num_workers", |
| 181 | type=int, | 176 | type=int, |
| 182 | default=0, | 177 | default=0, |
| @@ -698,7 +693,7 @@ def main(): | |||
| 698 | elif args.mixed_precision == "bf16": | 693 | elif args.mixed_precision == "bf16": |
| 699 | weight_dtype = torch.bfloat16 | 694 | weight_dtype = torch.bfloat16 |
| 700 | 695 | ||
| 701 | def keyword_filter(item: CSVDataItem): | 696 | def keyword_filter(item: VlpnDataItem): |
| 702 | cond3 = args.collection is None or args.collection in item.collection | 697 | cond3 = args.collection is None or args.collection in item.collection |
| 703 | cond4 = args.exclude_collections is None or not any( | 698 | cond4 = args.exclude_collections is None or not any( |
| 704 | collection in item.collection | 699 | collection in item.collection |
| @@ -733,7 +728,7 @@ def main(): | |||
| 733 | } | 728 | } |
| 734 | return batch | 729 | return batch |
| 735 | 730 | ||
| 736 | datamodule = CSVDataModule( | 731 | datamodule = VlpnDataModule( |
| 737 | data_file=args.train_data_file, | 732 | data_file=args.train_data_file, |
| 738 | batch_size=args.train_batch_size, | 733 | batch_size=args.train_batch_size, |
| 739 | prompt_processor=prompt_processor, | 734 | prompt_processor=prompt_processor, |
| @@ -742,7 +737,6 @@ def main(): | |||
| 742 | size=args.resolution, | 737 | size=args.resolution, |
| 743 | repeats=args.repeats, | 738 | repeats=args.repeats, |
| 744 | dropout=args.tag_dropout, | 739 | dropout=args.tag_dropout, |
| 745 | center_crop=args.center_crop, | ||
| 746 | template_key=args.train_data_template, | 740 | template_key=args.train_data_template, |
| 747 | valid_set_size=args.valid_set_size, | 741 | valid_set_size=args.valid_set_size, |
| 748 | num_workers=args.dataloader_num_workers, | 742 | num_workers=args.dataloader_num_workers, |
