summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/dreambooth.py b/dreambooth.py
index dd93e09..c8d18a8 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -296,6 +296,7 @@ class Checkpointer:
296 tokenizer, 296 tokenizer,
297 text_encoder, 297 text_encoder,
298 output_dir, 298 output_dir,
299 instance_identifier,
299 sample_image_size, 300 sample_image_size,
300 sample_batches, 301 sample_batches,
301 sample_batch_size, 302 sample_batch_size,
@@ -308,6 +309,7 @@ class Checkpointer:
308 self.tokenizer = tokenizer 309 self.tokenizer = tokenizer
309 self.text_encoder = text_encoder 310 self.text_encoder = text_encoder
310 self.output_dir = output_dir 311 self.output_dir = output_dir
312 self.instance_identifier = instance_identifier
311 self.sample_image_size = sample_image_size 313 self.sample_image_size = sample_image_size
312 self.seed = seed 314 self.seed = seed
313 self.sample_batches = sample_batches 315 self.sample_batches = sample_batches
@@ -373,7 +375,8 @@ class Checkpointer:
373 375
374 for i in range(self.sample_batches): 376 for i in range(self.sample_batches):
375 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 377 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
376 prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] 378 prompt = [prompt.format(self.instance_identifier)
379 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
377 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] 380 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
378 381
379 with self.accelerator.autocast(): 382 with self.accelerator.autocast():
@@ -551,7 +554,7 @@ def main():
551 554
552 for batch in batched_data: 555 for batch in batched_data:
553 image_name = [p[1] for p in batch] 556 image_name = [p[1] for p in batch]
554 prompt = [p[2] for p in batch] 557 prompt = [p[2].format(args.class_identifier) for p in batch]
555 nprompt = [p[3] for p in batch] 558 nprompt = [p[3] for p in batch]
556 559
557 with accelerator.autocast(): 560 with accelerator.autocast():
@@ -580,6 +583,7 @@ def main():
580 tokenizer=tokenizer, 583 tokenizer=tokenizer,
581 text_encoder=text_encoder, 584 text_encoder=text_encoder,
582 output_dir=basepath, 585 output_dir=basepath,
586 instance_identifier=args.instance_identifier,
583 sample_image_size=args.sample_image_size, 587 sample_image_size=args.sample_image_size,
584 sample_batch_size=args.sample_batch_size, 588 sample_batch_size=args.sample_batch_size,
585 sample_batches=args.sample_batches, 589 sample_batches=args.sample_batches,