diff options
| -rw-r--r-- | data/dreambooth/csv.py | 4 | ||||
| -rw-r--r-- | dreambooth.py | 8 |
2 files changed, 6 insertions, 6 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4ebdc13..9075979 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -150,7 +150,7 @@ class CSVDataset(Dataset): | |||
| 150 | max_length=self.tokenizer.model_max_length, | 150 | max_length=self.tokenizer.model_max_length, |
| 151 | ).input_ids | 151 | ).input_ids |
| 152 | 152 | ||
| 153 | if self.class_identifier: | 153 | if self.class_identifier is not None: |
| 154 | class_image = Image.open(class_image_path) | 154 | class_image = Image.open(class_image_path) |
| 155 | if not class_image.mode == "RGB": | 155 | if not class_image.mode == "RGB": |
| 156 | class_image = class_image.convert("RGB") | 156 | class_image = class_image.convert("RGB") |
| @@ -177,7 +177,7 @@ class CSVDataset(Dataset): | |||
| 177 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 177 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 178 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 178 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
| 179 | 179 | ||
| 180 | if self.class_identifier: | 180 | if self.class_identifier is not None: |
| 181 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 181 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
| 182 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 182 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] |
| 183 | 183 | ||
diff --git a/dreambooth.py b/dreambooth.py index 1bff414..dd93e09 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -497,12 +497,12 @@ def main(): | |||
| 497 | pixel_values = [example["instance_images"] for example in examples] | 497 | pixel_values = [example["instance_images"] for example in examples] |
| 498 | 498 | ||
| 499 | # concat class and instance examples for prior preservation | 499 | # concat class and instance examples for prior preservation |
| 500 | if args.class_identifier and "class_prompt_ids" in examples[0]: | 500 | if args.class_identifier is not None and "class_prompt_ids" in examples[0]: |
| 501 | input_ids += [example["class_prompt_ids"] for example in examples] | 501 | input_ids += [example["class_prompt_ids"] for example in examples] |
| 502 | pixel_values += [example["class_images"] for example in examples] | 502 | pixel_values += [example["class_images"] for example in examples] |
| 503 | 503 | ||
| 504 | pixel_values = torch.stack(pixel_values) | 504 | pixel_values = torch.stack(pixel_values) |
| 505 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() | 505 | pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) |
| 506 | 506 | ||
| 507 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 507 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids |
| 508 | 508 | ||
| @@ -529,7 +529,7 @@ def main(): | |||
| 529 | datamodule.prepare_data() | 529 | datamodule.prepare_data() |
| 530 | datamodule.setup() | 530 | datamodule.setup() |
| 531 | 531 | ||
| 532 | if args.class_identifier: | 532 | if args.class_identifier is not None: |
| 533 | missing_data = [item for item in datamodule.data if not item[1].exists()] | 533 | missing_data = [item for item in datamodule.data if not item[1].exists()] |
| 534 | 534 | ||
| 535 | if len(missing_data) != 0: | 535 | if len(missing_data) != 0: |
| @@ -686,7 +686,7 @@ def main(): | |||
| 686 | # Predict the noise residual | 686 | # Predict the noise residual |
| 687 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 687 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 688 | 688 | ||
| 689 | if args.class_identifier: | 689 | if args.class_identifier is not None: |
| 690 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 690 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
| 691 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 691 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
| 692 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 692 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
