diff options
-rw-r--r-- | data/dreambooth/csv.py | 4 | ||||
-rw-r--r-- | dreambooth.py | 8 |
2 files changed, 6 insertions, 6 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4ebdc13..9075979 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -150,7 +150,7 @@ class CSVDataset(Dataset): | |||
150 | max_length=self.tokenizer.model_max_length, | 150 | max_length=self.tokenizer.model_max_length, |
151 | ).input_ids | 151 | ).input_ids |
152 | 152 | ||
153 | if self.class_identifier: | 153 | if self.class_identifier is not None: |
154 | class_image = Image.open(class_image_path) | 154 | class_image = Image.open(class_image_path) |
155 | if not class_image.mode == "RGB": | 155 | if not class_image.mode == "RGB": |
156 | class_image = class_image.convert("RGB") | 156 | class_image = class_image.convert("RGB") |
@@ -177,7 +177,7 @@ class CSVDataset(Dataset): | |||
177 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 177 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
178 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 178 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
179 | 179 | ||
180 | if self.class_identifier: | 180 | if self.class_identifier is not None: |
181 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 181 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
182 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 182 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] |
183 | 183 | ||
diff --git a/dreambooth.py b/dreambooth.py index 1bff414..dd93e09 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -497,12 +497,12 @@ def main(): | |||
497 | pixel_values = [example["instance_images"] for example in examples] | 497 | pixel_values = [example["instance_images"] for example in examples] |
498 | 498 | ||
499 | # concat class and instance examples for prior preservation | 499 | # concat class and instance examples for prior preservation |
500 | if args.class_identifier and "class_prompt_ids" in examples[0]: | 500 | if args.class_identifier is not None and "class_prompt_ids" in examples[0]: |
501 | input_ids += [example["class_prompt_ids"] for example in examples] | 501 | input_ids += [example["class_prompt_ids"] for example in examples] |
502 | pixel_values += [example["class_images"] for example in examples] | 502 | pixel_values += [example["class_images"] for example in examples] |
503 | 503 | ||
504 | pixel_values = torch.stack(pixel_values) | 504 | pixel_values = torch.stack(pixel_values) |
505 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() | 505 | pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) |
506 | 506 | ||
507 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 507 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids |
508 | 508 | ||
@@ -529,7 +529,7 @@ def main(): | |||
529 | datamodule.prepare_data() | 529 | datamodule.prepare_data() |
530 | datamodule.setup() | 530 | datamodule.setup() |
531 | 531 | ||
532 | if args.class_identifier: | 532 | if args.class_identifier is not None: |
533 | missing_data = [item for item in datamodule.data if not item[1].exists()] | 533 | missing_data = [item for item in datamodule.data if not item[1].exists()] |
534 | 534 | ||
535 | if len(missing_data) != 0: | 535 | if len(missing_data) != 0: |
@@ -686,7 +686,7 @@ def main(): | |||
686 | # Predict the noise residual | 686 | # Predict the noise residual |
687 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 687 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
688 | 688 | ||
689 | if args.class_identifier: | 689 | if args.class_identifier is not None: |
690 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 690 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
691 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 691 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
692 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 692 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |