summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-23 21:47:12 +0100
committerVolpeon <git@volpeon.ink>2022-12-23 21:47:12 +0100
commit1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch)
tree42d3302610046dbc5d39d254f7a2d5d5f601aa18 /training/util.py
parentFix (diff)
downloadtextual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.gz
textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.bz2
textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.zip
Improved class prompt handling
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/training/util.py b/training/util.py
index e8d22ae..000173d 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,7 +44,6 @@ class CheckpointerBase:
44 self, 44 self,
45 datamodule, 45 datamodule,
46 output_dir: Path, 46 output_dir: Path,
47 instance_identifier,
48 placeholder_token, 47 placeholder_token,
49 placeholder_token_id, 48 placeholder_token_id,
50 sample_image_size, 49 sample_image_size,
@@ -54,7 +53,6 @@ class CheckpointerBase:
54 ): 53 ):
55 self.datamodule = datamodule 54 self.datamodule = datamodule
56 self.output_dir = output_dir 55 self.output_dir = output_dir
57 self.instance_identifier = instance_identifier
58 self.placeholder_token = placeholder_token 56 self.placeholder_token = placeholder_token
59 self.placeholder_token_id = placeholder_token_id 57 self.placeholder_token_id = placeholder_token_id
60 self.sample_image_size = sample_image_size 58 self.sample_image_size = sample_image_size
@@ -90,7 +88,7 @@ class CheckpointerBase:
90 if j * data.batch_size < self.sample_batch_size * self.sample_batches 88 if j * data.batch_size < self.sample_batch_size * self.sample_batches
91 ] 89 ]
92 prompts = [ 90 prompts = [
93 prompt.format(identifier=self.instance_identifier) 91 prompt
94 for batch in batches 92 for batch in batches
95 for prompt in batch["prompts"] 93 for prompt in batch["prompts"]
96 ] 94 ]