diff options
author | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 |
commit | 7505f7e843dc719622a15f4ee301609813763d77 (patch) | |
tree | fe67640dce9fec4f625d6d1600c696cd7de006ee /train_dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.gz textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.bz2 textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.zip |
Code simplifications, avoid autocast
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index e239833..2c765ec 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -389,6 +389,7 @@ def parse_args(): | |||
389 | class Checkpointer(CheckpointerBase): | 389 | class Checkpointer(CheckpointerBase): |
390 | def __init__( | 390 | def __init__( |
391 | self, | 391 | self, |
392 | weight_dtype, | ||
392 | datamodule, | 393 | datamodule, |
393 | accelerator, | 394 | accelerator, |
394 | vae, | 395 | vae, |
@@ -416,6 +417,7 @@ class Checkpointer(CheckpointerBase): | |||
416 | sample_batch_size=sample_batch_size | 417 | sample_batch_size=sample_batch_size |
417 | ) | 418 | ) |
418 | 419 | ||
420 | self.weight_dtype = weight_dtype | ||
419 | self.accelerator = accelerator | 421 | self.accelerator = accelerator |
420 | self.vae = vae | 422 | self.vae = vae |
421 | self.unet = unet | 423 | self.unet = unet |
@@ -452,6 +454,12 @@ class Checkpointer(CheckpointerBase): | |||
452 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) | 454 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) |
453 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 455 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
454 | 456 | ||
457 | orig_unet_dtype = unet.dtype | ||
458 | orig_text_encoder_dtype = text_encoder.dtype | ||
459 | |||
460 | unet.to(dtype=self.weight_dtype) | ||
461 | text_encoder.to(dtype=self.weight_dtype) | ||
462 | |||
455 | pipeline = VlpnStableDiffusion( | 463 | pipeline = VlpnStableDiffusion( |
456 | text_encoder=text_encoder, | 464 | text_encoder=text_encoder, |
457 | vae=self.vae, | 465 | vae=self.vae, |
@@ -463,6 +471,9 @@ class Checkpointer(CheckpointerBase): | |||
463 | 471 | ||
464 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 472 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) |
465 | 473 | ||
474 | unet.to(dtype=orig_unet_dtype) | ||
475 | text_encoder.to(dtype=orig_text_encoder_dtype) | ||
476 | |||
466 | del unet | 477 | del unet |
467 | del text_encoder | 478 | del text_encoder |
468 | del pipeline | 479 | del pipeline |
@@ -798,6 +809,7 @@ def main(): | |||
798 | max_acc_val = 0.0 | 809 | max_acc_val = 0.0 |
799 | 810 | ||
800 | checkpointer = Checkpointer( | 811 | checkpointer = Checkpointer( |
812 | weight_dtype=weight_dtype, | ||
801 | datamodule=datamodule, | 813 | datamodule=datamodule, |
802 | accelerator=accelerator, | 814 | accelerator=accelerator, |
803 | vae=vae, | 815 | vae=vae, |