From 0462f5d20fdc82806753629b6f5c5fb39a88d1d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Oct 2022 09:19:52 +0200 Subject: Bugfix --- dreambooth.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 1bff414..dd93e09 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -497,12 +497,12 @@ def main(): pixel_values = [example["instance_images"] for example in examples] # concat class and instance examples for prior preservation - if args.class_identifier and "class_prompt_ids" in examples[0]: + if args.class_identifier is not None and "class_prompt_ids" in examples[0]: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids @@ -529,7 +529,7 @@ def main(): datamodule.prepare_data() datamodule.setup() - if args.class_identifier: + if args.class_identifier is not None: missing_data = [item for item in datamodule.data if not item[1].exists()] if len(missing_data) != 0: @@ -686,7 +686,7 @@ def main(): # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - if args.class_identifier: + if args.class_identifier is not None: # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) noise, noise_prior = torch.chunk(noise, 2, dim=0) -- cgit v1.2.3-54-g00ecf