From 6c38d0088ece492696a7bc94a5cb43a48289452a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 09:35:42 +0100 Subject: Fix --- data/csv.py | 2 +- train_dreambooth.py | 4 ++-- train_ti.py | 4 ++-- training/common.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/data/csv.py b/data/csv.py index df3ee77..b058a3e 100644 --- a/data/csv.py +++ b/data/csv.py @@ -121,7 +121,7 @@ def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples): inputs = unify_input_ids(tokenizer, input_ids) batch = { - "with_prior": torch.tensor(with_prior), + "with_prior": torch.tensor([with_prior] * len(examples)), "prompt_ids": prompts.input_ids, "nprompt_ids": nprompts.input_ids, "input_ids": inputs.input_ids, diff --git a/train_dreambooth.py b/train_dreambooth.py index c180170..53776ba 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -523,7 +523,7 @@ class Checkpointer(CheckpointerBase): torch.cuda.empty_cache() @torch.no_grad() - def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): + def save_samples(self, step): unet = self.accelerator.unwrap_model(self.unet) text_encoder = self.accelerator.unwrap_model(self.text_encoder) @@ -545,7 +545,7 @@ class Checkpointer(CheckpointerBase): ).to(self.accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + super().save_samples(pipeline, step) unet.to(dtype=orig_unet_dtype) text_encoder.to(dtype=orig_text_encoder_dtype) diff --git a/train_ti.py b/train_ti.py index d752927..928b721 100644 --- a/train_ti.py +++ b/train_ti.py @@ -531,7 +531,7 @@ class Checkpointer(CheckpointerBase): del text_encoder @torch.no_grad() - def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): + def save_samples(self, step): text_encoder = self.accelerator.unwrap_model(self.text_encoder) ema_context = self.ema_embeddings.apply_temporary( @@ -550,7 +550,7 @@ class Checkpointer(CheckpointerBase): ).to(self.accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + super().save_samples(pipeline, step) text_encoder.to(dtype=orig_dtype) diff --git a/training/common.py b/training/common.py index f5ab326..8083137 100644 --- a/training/common.py +++ b/training/common.py @@ -184,7 +184,7 @@ def loss_step( else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - if batch["with_prior"]: + if batch["with_prior"].all(): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) -- cgit v1.2.3-70-g09d2