diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 |
| commit | 7505f7e843dc719622a15f4ee301609813763d77 (patch) | |
| tree | fe67640dce9fec4f625d6d1600c696cd7de006ee /train_dreambooth.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.gz textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.bz2 textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.zip | |
Code simplifications, avoid autocast
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index e239833..2c765ec 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -389,6 +389,7 @@ def parse_args(): | |||
| 389 | class Checkpointer(CheckpointerBase): | 389 | class Checkpointer(CheckpointerBase): |
| 390 | def __init__( | 390 | def __init__( |
| 391 | self, | 391 | self, |
| 392 | weight_dtype, | ||
| 392 | datamodule, | 393 | datamodule, |
| 393 | accelerator, | 394 | accelerator, |
| 394 | vae, | 395 | vae, |
| @@ -416,6 +417,7 @@ class Checkpointer(CheckpointerBase): | |||
| 416 | sample_batch_size=sample_batch_size | 417 | sample_batch_size=sample_batch_size |
| 417 | ) | 418 | ) |
| 418 | 419 | ||
| 420 | self.weight_dtype = weight_dtype | ||
| 419 | self.accelerator = accelerator | 421 | self.accelerator = accelerator |
| 420 | self.vae = vae | 422 | self.vae = vae |
| 421 | self.unet = unet | 423 | self.unet = unet |
| @@ -452,6 +454,12 @@ class Checkpointer(CheckpointerBase): | |||
| 452 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) | 454 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) |
| 453 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 455 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
| 454 | 456 | ||
| 457 | orig_unet_dtype = unet.dtype | ||
| 458 | orig_text_encoder_dtype = text_encoder.dtype | ||
| 459 | |||
| 460 | unet.to(dtype=self.weight_dtype) | ||
| 461 | text_encoder.to(dtype=self.weight_dtype) | ||
| 462 | |||
| 455 | pipeline = VlpnStableDiffusion( | 463 | pipeline = VlpnStableDiffusion( |
| 456 | text_encoder=text_encoder, | 464 | text_encoder=text_encoder, |
| 457 | vae=self.vae, | 465 | vae=self.vae, |
| @@ -463,6 +471,9 @@ class Checkpointer(CheckpointerBase): | |||
| 463 | 471 | ||
| 464 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 472 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) |
| 465 | 473 | ||
| 474 | unet.to(dtype=orig_unet_dtype) | ||
| 475 | text_encoder.to(dtype=orig_text_encoder_dtype) | ||
| 476 | |||
| 466 | del unet | 477 | del unet |
| 467 | del text_encoder | 478 | del text_encoder |
| 468 | del pipeline | 479 | del pipeline |
| @@ -798,6 +809,7 @@ def main(): | |||
| 798 | max_acc_val = 0.0 | 809 | max_acc_val = 0.0 |
| 799 | 810 | ||
| 800 | checkpointer = Checkpointer( | 811 | checkpointer = Checkpointer( |
| 812 | weight_dtype=weight_dtype, | ||
| 801 | datamodule=datamodule, | 813 | datamodule=datamodule, |
| 802 | accelerator=accelerator, | 814 | accelerator=accelerator, |
| 803 | vae=vae, | 815 | vae=vae, |
