From 3ee13893f9a4973ac75f45fe9318c35760dd4b1f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 7 Jan 2023 13:57:46 +0100 Subject: Added progressive aspect ratio bucketing --- train_dreambooth.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index e8256be..d265bcc 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -22,7 +22,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from data.csv import CSVDataModule, CSVDataItem +from data.csv import VlpnDataModule, VlpnDataItem from training.common import run_model from training.optimization import get_one_cycle_schedule from training.lr import LRFinder @@ -171,11 +171,6 @@ def parse_args(): " resolution" ), ) - parser.add_argument( - "--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution" - ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -698,7 +693,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - def keyword_filter(item: CSVDataItem): + def keyword_filter(item: VlpnDataItem): cond3 = args.collection is None or args.collection in item.collection cond4 = args.exclude_collections is None or not any( collection in item.collection @@ -733,7 +728,7 @@ def main(): } return batch - datamodule = CSVDataModule( + datamodule = VlpnDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, @@ -742,7 +737,6 @@ def main(): size=args.resolution, repeats=args.repeats, dropout=args.tag_dropout, - center_crop=args.center_crop, template_key=args.train_data_template, valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, -- cgit v1.2.3-54-g00ecf