summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-23 21:53:46 +0100
committerVolpeon <git@volpeon.ink>2022-12-23 21:53:46 +0100
commitb491a817088790219e052b86173e128c55b597f8 (patch)
tree49ea8480da78b002f36b15b52d1b9bc501837819
parentImproved class prompt handling (diff)
downloadtextual-inversion-diff-b491a817088790219e052b86173e128c55b597f8.tar.gz
textual-inversion-diff-b491a817088790219e052b86173e128c55b597f8.tar.bz2
textual-inversion-diff-b491a817088790219e052b86173e128c55b597f8.zip
num_class_images is now class images per train image
-rw-r--r--data/csv.py4
-rw-r--r--train_dreambooth.py2
-rw-r--r--train_ti.py2
3 files changed, 4 insertions, 4 deletions
diff --git a/data/csv.py b/data/csv.py
index e25dd3f..edce2b1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -49,7 +49,7 @@ class CSVDataModule():
49 data_file: str, 49 data_file: str,
50 prompt_processor: PromptProcessor, 50 prompt_processor: PromptProcessor,
51 class_subdir: str = "cls", 51 class_subdir: str = "cls",
52 num_class_images: int = 100, 52 num_class_images: int = 1,
53 size: int = 512, 53 size: int = 512,
54 repeats: int = 1, 54 repeats: int = 1,
55 dropout: float = 0, 55 dropout: float = 0,
@@ -117,7 +117,7 @@ class CSVDataModule():
117 return [item for item in items if self.filter(item)] 117 return [item for item in items if self.filter(item)]
118 118
119 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: 119 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
120 image_multiplier = max(math.ceil(num_class_images / len(items)), 1) 120 image_multiplier = max(num_class_images, 1)
121 121
122 return [ 122 return [
123 CSVDataItem( 123 CSVDataItem(
diff --git a/train_dreambooth.py b/train_dreambooth.py
index ff67d12..2f913e7 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -101,7 +101,7 @@ def parse_args():
101 parser.add_argument( 101 parser.add_argument(
102 "--num_class_images", 102 "--num_class_images",
103 type=int, 103 type=int,
104 default=400, 104 default=1,
105 help="How many class images to generate." 105 help="How many class images to generate."
106 ) 106 )
107 parser.add_argument( 107 parser.add_argument(
diff --git a/train_ti.py b/train_ti.py
index 55daa35..e272b5d 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -83,7 +83,7 @@ def parse_args():
83 parser.add_argument( 83 parser.add_argument(
84 "--num_class_images", 84 "--num_class_images",
85 type=int, 85 type=int,
86 default=400, 86 default=1,
87 help="How many class images to generate." 87 help="How many class images to generate."
88 ) 88 )
89 parser.add_argument( 89 parser.add_argument(