diff options
| -rw-r--r-- | dreambooth.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/dreambooth.py b/dreambooth.py index dd93e09..c8d18a8 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -296,6 +296,7 @@ class Checkpointer: | |||
| 296 | tokenizer, | 296 | tokenizer, |
| 297 | text_encoder, | 297 | text_encoder, |
| 298 | output_dir, | 298 | output_dir, |
| 299 | instance_identifier, | ||
| 299 | sample_image_size, | 300 | sample_image_size, |
| 300 | sample_batches, | 301 | sample_batches, |
| 301 | sample_batch_size, | 302 | sample_batch_size, |
| @@ -308,6 +309,7 @@ class Checkpointer: | |||
| 308 | self.tokenizer = tokenizer | 309 | self.tokenizer = tokenizer |
| 309 | self.text_encoder = text_encoder | 310 | self.text_encoder = text_encoder |
| 310 | self.output_dir = output_dir | 311 | self.output_dir = output_dir |
| 312 | self.instance_identifier = instance_identifier | ||
| 311 | self.sample_image_size = sample_image_size | 313 | self.sample_image_size = sample_image_size |
| 312 | self.seed = seed | 314 | self.seed = seed |
| 313 | self.sample_batches = sample_batches | 315 | self.sample_batches = sample_batches |
| @@ -373,7 +375,8 @@ class Checkpointer: | |||
| 373 | 375 | ||
| 374 | for i in range(self.sample_batches): | 376 | for i in range(self.sample_batches): |
| 375 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 377 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
| 376 | prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | 378 | prompt = [prompt.format(self.instance_identifier) |
| 379 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | ||
| 377 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | 380 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] |
| 378 | 381 | ||
| 379 | with self.accelerator.autocast(): | 382 | with self.accelerator.autocast(): |
| @@ -551,7 +554,7 @@ def main(): | |||
| 551 | 554 | ||
| 552 | for batch in batched_data: | 555 | for batch in batched_data: |
| 553 | image_name = [p[1] for p in batch] | 556 | image_name = [p[1] for p in batch] |
| 554 | prompt = [p[2] for p in batch] | 557 | prompt = [p[2].format(args.class_identifier) for p in batch] |
| 555 | nprompt = [p[3] for p in batch] | 558 | nprompt = [p[3] for p in batch] |
| 556 | 559 | ||
| 557 | with accelerator.autocast(): | 560 | with accelerator.autocast(): |
| @@ -580,6 +583,7 @@ def main(): | |||
| 580 | tokenizer=tokenizer, | 583 | tokenizer=tokenizer, |
| 581 | text_encoder=text_encoder, | 584 | text_encoder=text_encoder, |
| 582 | output_dir=basepath, | 585 | output_dir=basepath, |
| 586 | instance_identifier=args.instance_identifier, | ||
| 583 | sample_image_size=args.sample_image_size, | 587 | sample_image_size=args.sample_image_size, |
| 584 | sample_batch_size=args.sample_batch_size, | 588 | sample_batch_size=args.sample_batch_size, |
| 585 | sample_batches=args.sample_batches, | 589 | sample_batches=args.sample_batches, |
