summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-24 08:32:02 +0100
committerVolpeon <git@volpeon.ink>2022-12-24 08:32:02 +0100
commit7218588c03be1e7fa5566b8836826e6b1c9065d2 (patch)
treea3f2c931e9f4e1469c3d1b524f4587d79904720f
parentBetter dataset prompt handling (diff)
downloadtextual-inversion-diff-7218588c03be1e7fa5566b8836826e6b1c9065d2.tar.gz
textual-inversion-diff-7218588c03be1e7fa5566b8836826e6b1c9065d2.tar.bz2
textual-inversion-diff-7218588c03be1e7fa5566b8836826e6b1c9065d2.zip
Fix
-rw-r--r--data/csv.py4
-rw-r--r--training/util.py5
2 files changed, 6 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index 265293b..b45ac77 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -106,11 +106,11 @@ class CSVDataModule():
106 expansions 106 expansions
107 ), 107 ),
108 prompt_to_keywords( 108 prompt_to_keywords(
109 cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), 109 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
110 expansions 110 expansions
111 ), 111 ),
112 prompt_to_keywords( 112 prompt_to_keywords(
113 prompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 113 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
114 expansions 114 expansions
115 ), 115 ),
116 ) 116 )
diff --git a/training/util.py b/training/util.py
index 000173d..a623dc5 100644
--- a/training/util.py
+++ b/training/util.py
@@ -74,6 +74,9 @@ class CheckpointerBase:
74 generator=generator, 74 generator=generator,
75 ) 75 )
76 76
77 grid_cols = max(self.sample_batch_size, 4)
78 grid_rows = self.sample_batches * self.sample_batch_size / grid_cols
79
77 with torch.autocast("cuda"), torch.inference_mode(): 80 with torch.autocast("cuda"), torch.inference_mode():
78 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 81 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
79 all_samples = [] 82 all_samples = []
@@ -119,7 +122,7 @@ class CheckpointerBase:
119 122
120 del samples 123 del samples
121 124
122 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) 125 image_grid = make_grid(all_samples, grid_rows, grid_cols)
123 image_grid.save(file_path, quality=85) 126 image_grid.save(file_path, quality=85)
124 127
125 del all_samples 128 del all_samples