summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py2
-rw-r--r--train_dreambooth.py4
-rw-r--r--train_ti.py4
-rw-r--r--training/common.py2
4 files changed, 6 insertions, 6 deletions
diff --git a/data/csv.py b/data/csv.py
index df3ee77..b058a3e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -121,7 +121,7 @@ def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
121 inputs = unify_input_ids(tokenizer, input_ids) 121 inputs = unify_input_ids(tokenizer, input_ids)
122 122
123 batch = { 123 batch = {
124 "with_prior": torch.tensor(with_prior), 124 "with_prior": torch.tensor([with_prior] * len(examples)),
125 "prompt_ids": prompts.input_ids, 125 "prompt_ids": prompts.input_ids,
126 "nprompt_ids": nprompts.input_ids, 126 "nprompt_ids": nprompts.input_ids,
127 "input_ids": inputs.input_ids, 127 "input_ids": inputs.input_ids,
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c180170..53776ba 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -523,7 +523,7 @@ class Checkpointer(CheckpointerBase):
523 torch.cuda.empty_cache() 523 torch.cuda.empty_cache()
524 524
525 @torch.no_grad() 525 @torch.no_grad()
526 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 526 def save_samples(self, step):
527 unet = self.accelerator.unwrap_model(self.unet) 527 unet = self.accelerator.unwrap_model(self.unet)
528 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 528 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
529 529
@@ -545,7 +545,7 @@ class Checkpointer(CheckpointerBase):
545 ).to(self.accelerator.device) 545 ).to(self.accelerator.device)
546 pipeline.set_progress_bar_config(dynamic_ncols=True) 546 pipeline.set_progress_bar_config(dynamic_ncols=True)
547 547
548 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) 548 super().save_samples(pipeline, step)
549 549
550 unet.to(dtype=orig_unet_dtype) 550 unet.to(dtype=orig_unet_dtype)
551 text_encoder.to(dtype=orig_text_encoder_dtype) 551 text_encoder.to(dtype=orig_text_encoder_dtype)
diff --git a/train_ti.py b/train_ti.py
index d752927..928b721 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -531,7 +531,7 @@ class Checkpointer(CheckpointerBase):
531 del text_encoder 531 del text_encoder
532 532
533 @torch.no_grad() 533 @torch.no_grad()
534 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 534 def save_samples(self, step):
535 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 535 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
536 536
537 ema_context = self.ema_embeddings.apply_temporary( 537 ema_context = self.ema_embeddings.apply_temporary(
@@ -550,7 +550,7 @@ class Checkpointer(CheckpointerBase):
550 ).to(self.accelerator.device) 550 ).to(self.accelerator.device)
551 pipeline.set_progress_bar_config(dynamic_ncols=True) 551 pipeline.set_progress_bar_config(dynamic_ncols=True)
552 552
553 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) 553 super().save_samples(pipeline, step)
554 554
555 text_encoder.to(dtype=orig_dtype) 555 text_encoder.to(dtype=orig_dtype)
556 556
diff --git a/training/common.py b/training/common.py
index f5ab326..8083137 100644
--- a/training/common.py
+++ b/training/common.py
@@ -184,7 +184,7 @@ def loss_step(
184 else: 184 else:
185 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 185 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
186 186
187 if batch["with_prior"]: 187 if batch["with_prior"].all():
188 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 188 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
189 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 189 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
190 target, target_prior = torch.chunk(target, 2, dim=0) 190 target, target_prior = torch.chunk(target, 2, dim=0)