From b491a817088790219e052b86173e128c55b597f8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Dec 2022 21:53:46 +0100 Subject: num_class_images is now class images per train image --- data/csv.py | 4 ++-- train_dreambooth.py | 2 +- train_ti.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/data/csv.py b/data/csv.py index e25dd3f..edce2b1 100644 --- a/data/csv.py +++ b/data/csv.py @@ -49,7 +49,7 @@ class CSVDataModule(): data_file: str, prompt_processor: PromptProcessor, class_subdir: str = "cls", - num_class_images: int = 100, + num_class_images: int = 1, size: int = 512, repeats: int = 1, dropout: float = 0, @@ -117,7 +117,7 @@ class CSVDataModule(): return [item for item in items if self.filter(item)] def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: - image_multiplier = max(math.ceil(num_class_images / len(items)), 1) + image_multiplier = max(num_class_images, 1) return [ CSVDataItem( diff --git a/train_dreambooth.py b/train_dreambooth.py index ff67d12..2f913e7 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -101,7 +101,7 @@ def parse_args(): parser.add_argument( "--num_class_images", type=int, - default=400, + default=1, help="How many class images to generate." ) parser.add_argument( diff --git a/train_ti.py b/train_ti.py index 55daa35..e272b5d 100644 --- a/train_ti.py +++ b/train_ti.py @@ -83,7 +83,7 @@ def parse_args(): parser.add_argument( "--num_class_images", type=int, - default=400, + default=1, help="How many class images to generate." ) parser.add_argument( -- cgit v1.2.3-54-g00ecf