summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/util.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/training/util.py b/training/util.py
index e8d22ae..000173d 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,7 +44,6 @@ class CheckpointerBase:
44 self, 44 self,
45 datamodule, 45 datamodule,
46 output_dir: Path, 46 output_dir: Path,
47 instance_identifier,
48 placeholder_token, 47 placeholder_token,
49 placeholder_token_id, 48 placeholder_token_id,
50 sample_image_size, 49 sample_image_size,
@@ -54,7 +53,6 @@ class CheckpointerBase:
54 ): 53 ):
55 self.datamodule = datamodule 54 self.datamodule = datamodule
56 self.output_dir = output_dir 55 self.output_dir = output_dir
57 self.instance_identifier = instance_identifier
58 self.placeholder_token = placeholder_token 56 self.placeholder_token = placeholder_token
59 self.placeholder_token_id = placeholder_token_id 57 self.placeholder_token_id = placeholder_token_id
60 self.sample_image_size = sample_image_size 58 self.sample_image_size = sample_image_size
@@ -90,7 +88,7 @@ class CheckpointerBase:
90 if j * data.batch_size < self.sample_batch_size * self.sample_batches 88 if j * data.batch_size < self.sample_batch_size * self.sample_batches
91 ] 89 ]
92 prompts = [ 90 prompts = [
93 prompt.format(identifier=self.instance_identifier) 91 prompt
94 for batch in batches 92 for batch in batches
95 for prompt in batch["prompts"] 93 for prompt in batch["prompts"]
96 ] 94 ]