From 0462f5d20fdc82806753629b6f5c5fb39a88d1d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Oct 2022 09:19:52 +0200 Subject: Bugfix --- data/dreambooth/csv.py | 4 ++-- dreambooth.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4ebdc13..9075979 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -150,7 +150,7 @@ class CSVDataset(Dataset): max_length=self.tokenizer.model_max_length, ).input_ids - if self.class_identifier: + if self.class_identifier is not None: class_image = Image.open(class_image_path) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") @@ -177,7 +177,7 @@ class CSVDataset(Dataset): example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] - if self.class_identifier: + if self.class_identifier is not None: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] diff --git a/dreambooth.py b/dreambooth.py index 1bff414..dd93e09 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -497,12 +497,12 @@ def main(): pixel_values = [example["instance_images"] for example in examples] # concat class and instance examples for prior preservation - if args.class_identifier and "class_prompt_ids" in examples[0]: + if args.class_identifier is not None and "class_prompt_ids" in examples[0]: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids @@ -529,7 +529,7 @@ def main(): datamodule.prepare_data() datamodule.setup() - if args.class_identifier: + if args.class_identifier is not None: missing_data = [item for item in datamodule.data if not item[1].exists()] if len(missing_data) != 0: @@ -686,7 +686,7 @@ def main(): # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - if args.class_identifier: + if args.class_identifier is not None: # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) noise, noise_prior = torch.chunk(noise, 2, dim=0) -- cgit v1.2.3-70-g09d2