diff options
-rw-r--r-- | dreambooth.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/dreambooth.py b/dreambooth.py index dd93e09..c8d18a8 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -296,6 +296,7 @@ class Checkpointer: | |||
296 | tokenizer, | 296 | tokenizer, |
297 | text_encoder, | 297 | text_encoder, |
298 | output_dir, | 298 | output_dir, |
299 | instance_identifier, | ||
299 | sample_image_size, | 300 | sample_image_size, |
300 | sample_batches, | 301 | sample_batches, |
301 | sample_batch_size, | 302 | sample_batch_size, |
@@ -308,6 +309,7 @@ class Checkpointer: | |||
308 | self.tokenizer = tokenizer | 309 | self.tokenizer = tokenizer |
309 | self.text_encoder = text_encoder | 310 | self.text_encoder = text_encoder |
310 | self.output_dir = output_dir | 311 | self.output_dir = output_dir |
312 | self.instance_identifier = instance_identifier | ||
311 | self.sample_image_size = sample_image_size | 313 | self.sample_image_size = sample_image_size |
312 | self.seed = seed | 314 | self.seed = seed |
313 | self.sample_batches = sample_batches | 315 | self.sample_batches = sample_batches |
@@ -373,7 +375,8 @@ class Checkpointer: | |||
373 | 375 | ||
374 | for i in range(self.sample_batches): | 376 | for i in range(self.sample_batches): |
375 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 377 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
376 | prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | 378 | prompt = [prompt.format(self.instance_identifier) |
379 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | ||
377 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | 380 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] |
378 | 381 | ||
379 | with self.accelerator.autocast(): | 382 | with self.accelerator.autocast(): |
@@ -551,7 +554,7 @@ def main(): | |||
551 | 554 | ||
552 | for batch in batched_data: | 555 | for batch in batched_data: |
553 | image_name = [p[1] for p in batch] | 556 | image_name = [p[1] for p in batch] |
554 | prompt = [p[2] for p in batch] | 557 | prompt = [p[2].format(args.class_identifier) for p in batch] |
555 | nprompt = [p[3] for p in batch] | 558 | nprompt = [p[3] for p in batch] |
556 | 559 | ||
557 | with accelerator.autocast(): | 560 | with accelerator.autocast(): |
@@ -580,6 +583,7 @@ def main(): | |||
580 | tokenizer=tokenizer, | 583 | tokenizer=tokenizer, |
581 | text_encoder=text_encoder, | 584 | text_encoder=text_encoder, |
582 | output_dir=basepath, | 585 | output_dir=basepath, |
586 | instance_identifier=args.instance_identifier, | ||
583 | sample_image_size=args.sample_image_size, | 587 | sample_image_size=args.sample_image_size, |
584 | sample_batch_size=args.sample_batch_size, | 588 | sample_batch_size=args.sample_batch_size, |
585 | sample_batches=args.sample_batches, | 589 | sample_batches=args.sample_batches, |