summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-07 13:57:46 +0100
committerVolpeon <git@volpeon.ink>2023-01-07 13:57:46 +0100
commit3ee13893f9a4973ac75f45fe9318c35760dd4b1f (patch)
treee652a54e6c241eef52ddb30f2d7048da8f306f7b /train_dreambooth.py
parentUpdate (diff)
downloadtextual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.gz
textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.bz2
textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.zip
Added progressive aspect ratio bucketing
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py12
1 files changed, 3 insertions, 9 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e8256be..d265bcc 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -22,7 +22,7 @@ from slugify import slugify
22 22
23from util import load_config, load_embeddings_from_dir 23from util import load_config, load_embeddings_from_dir
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule, CSVDataItem 25from data.csv import VlpnDataModule, VlpnDataItem
26from training.common import run_model 26from training.common import run_model
27from training.optimization import get_one_cycle_schedule 27from training.optimization import get_one_cycle_schedule
28from training.lr import LRFinder 28from training.lr import LRFinder
@@ -172,11 +172,6 @@ def parse_args():
172 ), 172 ),
173 ) 173 )
174 parser.add_argument( 174 parser.add_argument(
175 "--center_crop",
176 action="store_true",
177 help="Whether to center crop images before resizing to resolution"
178 )
179 parser.add_argument(
180 "--dataloader_num_workers", 175 "--dataloader_num_workers",
181 type=int, 176 type=int,
182 default=0, 177 default=0,
@@ -698,7 +693,7 @@ def main():
698 elif args.mixed_precision == "bf16": 693 elif args.mixed_precision == "bf16":
699 weight_dtype = torch.bfloat16 694 weight_dtype = torch.bfloat16
700 695
701 def keyword_filter(item: CSVDataItem): 696 def keyword_filter(item: VlpnDataItem):
702 cond3 = args.collection is None or args.collection in item.collection 697 cond3 = args.collection is None or args.collection in item.collection
703 cond4 = args.exclude_collections is None or not any( 698 cond4 = args.exclude_collections is None or not any(
704 collection in item.collection 699 collection in item.collection
@@ -733,7 +728,7 @@ def main():
733 } 728 }
734 return batch 729 return batch
735 730
736 datamodule = CSVDataModule( 731 datamodule = VlpnDataModule(
737 data_file=args.train_data_file, 732 data_file=args.train_data_file,
738 batch_size=args.train_batch_size, 733 batch_size=args.train_batch_size,
739 prompt_processor=prompt_processor, 734 prompt_processor=prompt_processor,
@@ -742,7 +737,6 @@ def main():
742 size=args.resolution, 737 size=args.resolution,
743 repeats=args.repeats, 738 repeats=args.repeats,
744 dropout=args.tag_dropout, 739 dropout=args.tag_dropout,
745 center_crop=args.center_crop,
746 template_key=args.train_data_template, 740 template_key=args.train_data_template,
747 valid_set_size=args.valid_set_size, 741 valid_set_size=args.valid_set_size,
748 num_workers=args.dataloader_num_workers, 742 num_workers=args.dataloader_num_workers,