 data/csv.py         | 2 +-
 train_dreambooth.py | 4 ++--
 train_ti.py         | 4 ++--
 training/common.py  | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index df3ee77..b058a3e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -121,7 +121,7 @@ def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
     inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
-        "with_prior": torch.tensor(with_prior),
+        "with_prior": torch.tensor([with_prior] * len(examples)),
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c180170..53776ba 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -523,7 +523,7 @@ class Checkpointer(CheckpointerBase):
         torch.cuda.empty_cache()
 
     @torch.no_grad()
-    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+    def save_samples(self, step):
         unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
@@ -545,7 +545,7 @@ class Checkpointer(CheckpointerBase):
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+        super().save_samples(pipeline, step)
 
         unet.to(dtype=orig_unet_dtype)
         text_encoder.to(dtype=orig_text_encoder_dtype)
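The save_samples signature in both trainers drops the per-call sampling parameters. A rough sketch of the narrowed contract between Checkpointer and its base class; how CheckpointerBase now obtains num_inference_steps, guidance_scale and eta is not shown in this patch, so the base-class body here is an assumption:

import torch

class CheckpointerBase:
    def save_samples(self, pipeline, step):
        # Assumption: sampling settings (num_inference_steps, guidance_scale, eta)
        # are held as configuration on the base class rather than passed per call.
        ...

class Checkpointer(CheckpointerBase):
    @torch.no_grad()
    def save_samples(self, step):
        pipeline = ...  # built as in the hunk above
        super().save_samples(pipeline, step)
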
diff --git a/train_ti.py b/train_ti.py
index d752927..928b721 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -531,7 +531,7 @@ class Checkpointer(CheckpointerBase):
         del text_encoder
 
     @torch.no_grad()
-    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+    def save_samples(self, step):
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         ema_context = self.ema_embeddings.apply_temporary(
@@ -550,7 +550,7 @@ class Checkpointer(CheckpointerBase):
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+        super().save_samples(pipeline, step)
 
         text_encoder.to(dtype=orig_dtype)
 
diff --git a/training/common.py b/training/common.py
index f5ab326..8083137 100644
--- a/training/common.py
+++ b/training/common.py
@@ -184,7 +184,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if batch["with_prior"]:
+    if batch["with_prior"].all():
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
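Since with_prior is now a per-example tensor, its truthiness check becomes .all(). For reference, a minimal sketch of the prior-preservation branch this condition guards; the prior-loss weighting and the remaining loss_step arguments are assumptions not shown in this diff:

import torch
import torch.nn.functional as F

def prior_preservation_loss(model_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # With prior preservation, the batch holds instance examples followed by
    # class ("prior") examples, so both tensors split cleanly in half.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    return loss + prior_loss  # assumption: equal weighting; the real code may scale prior_loss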
