summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2022-10-04 09:19:52 +0200
committer Volpeon <git@volpeon.ink> 2022-10-04 09:19:52 +0200
commit 0462f5d20fdc82806753629b6f5c5fb39a88d1d2 (patch)
tree fb9f73860a597606539471a27f583c4a8337961d
parent Fix (diff)
download textual-inversion-diff-0462f5d20fdc82806753629b6f5c5fb39a88d1d2.tar.gz
textual-inversion-diff-0462f5d20fdc82806753629b6f5c5fb39a88d1d2.tar.bz2
textual-inversion-diff-0462f5d20fdc82806753629b6f5c5fb39a88d1d2.zip
Bugfix
-rw-r--r-- data/dreambooth/csv.py 4
-rw-r--r-- dreambooth.py 8
2 files changed, 6 insertions, 6 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 4ebdc13..9075979 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -150,7 +150,7 @@ class CSVDataset(Dataset):
150 max_length=self.tokenizer.model_max_length, 150 max_length=self.tokenizer.model_max_length,
151 ).input_ids 151 ).input_ids
152 152
153 if self.class_identifier: 153 if self.class_identifier is not None:
154 class_image = Image.open(class_image_path) 154 class_image = Image.open(class_image_path)
155 if not class_image.mode == "RGB": 155 if not class_image.mode == "RGB":
156 class_image = class_image.convert("RGB") 156 class_image = class_image.convert("RGB")
@@ -177,7 +177,7 @@ class CSVDataset(Dataset):
177 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 177 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
178 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] 178 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
179 179
180 if self.class_identifier: 180 if self.class_identifier is not None:
181 example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) 181 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
182 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] 182 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
183 183
diff --git a/dreambooth.py b/dreambooth.py
index 1bff414..dd93e09 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -497,12 +497,12 @@ def main():
497 pixel_values = [example["instance_images"] for example in examples] 497 pixel_values = [example["instance_images"] for example in examples]
498 498
499 # concat class and instance examples for prior preservation 499 # concat class and instance examples for prior preservation
500 if args.class_identifier and "class_prompt_ids" in examples[0]: 500 if args.class_identifier is not None and "class_prompt_ids" in examples[0]:
501 input_ids += [example["class_prompt_ids"] for example in examples] 501 input_ids += [example["class_prompt_ids"] for example in examples]
502 pixel_values += [example["class_images"] for example in examples] 502 pixel_values += [example["class_images"] for example in examples]
503 503
504 pixel_values = torch.stack(pixel_values) 504 pixel_values = torch.stack(pixel_values)
505 pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 505 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
506 506
507 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 507 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
508 508
@@ -529,7 +529,7 @@ def main():
529 datamodule.prepare_data() 529 datamodule.prepare_data()
530 datamodule.setup() 530 datamodule.setup()
531 531
532 if args.class_identifier: 532 if args.class_identifier is not None:
533 missing_data = [item for item in datamodule.data if not item[1].exists()] 533 missing_data = [item for item in datamodule.data if not item[1].exists()]
534 534
535 if len(missing_data) != 0: 535 if len(missing_data) != 0:
@@ -686,7 +686,7 @@ def main():
686 # Predict the noise residual 686 # Predict the noise residual
687 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 687 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
688 688
689 if args.class_identifier: 689 if args.class_identifier is not None:
690 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 690 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
691 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 691 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
692 noise, noise_prior = torch.chunk(noise, 2, dim=0) 692 noise, noise_prior = torch.chunk(noise, 2, dim=0)