From 63a5ed820d2b635a73cbd260f957027a5d2cc605 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Oct 2022 09:56:44 +0200 Subject: Fix: Prompts --- dreambooth.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index dd93e09..c8d18a8 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -296,6 +296,7 @@ class Checkpointer: tokenizer, text_encoder, output_dir, + instance_identifier, sample_image_size, sample_batches, sample_batch_size, @@ -308,6 +309,7 @@ class Checkpointer: self.tokenizer = tokenizer self.text_encoder = text_encoder self.output_dir = output_dir + self.instance_identifier = instance_identifier self.sample_image_size = sample_image_size self.seed = seed self.sample_batches = sample_batches @@ -373,7 +375,8 @@ class Checkpointer: for i in range(self.sample_batches): batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] - prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] + prompt = [prompt.format(self.instance_identifier) + for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] with self.accelerator.autocast(): @@ -551,7 +554,7 @@ def main(): for batch in batched_data: image_name = [p[1] for p in batch] - prompt = [p[2] for p in batch] + prompt = [p[2].format(args.class_identifier) for p in batch] nprompt = [p[3] for p in batch] with accelerator.autocast(): @@ -580,6 +583,7 @@ def main(): tokenizer=tokenizer, text_encoder=text_encoder, output_dir=basepath, + instance_identifier=args.instance_identifier, sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, sample_batches=args.sample_batches, -- cgit v1.2.3-54-g00ecf