summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-25 23:50:24 +0100
committerVolpeon <git@volpeon.ink>2022-12-25 23:50:24 +0100
commit7505f7e843dc719622a15f4ee301609813763d77 (patch)
treefe67640dce9fec4f625d6d1600c696cd7de006ee /train_dreambooth.py
parentUpdate (diff)
downloadtextual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.gz
textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.bz2
textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.zip
Code simplifications, avoid autocast
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e239833..2c765ec 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -389,6 +389,7 @@ def parse_args():
389class Checkpointer(CheckpointerBase): 389class Checkpointer(CheckpointerBase):
390 def __init__( 390 def __init__(
391 self, 391 self,
392 weight_dtype,
392 datamodule, 393 datamodule,
393 accelerator, 394 accelerator,
394 vae, 395 vae,
@@ -416,6 +417,7 @@ class Checkpointer(CheckpointerBase):
416 sample_batch_size=sample_batch_size 417 sample_batch_size=sample_batch_size
417 ) 418 )
418 419
420 self.weight_dtype = weight_dtype
419 self.accelerator = accelerator 421 self.accelerator = accelerator
420 self.vae = vae 422 self.vae = vae
421 self.unet = unet 423 self.unet = unet
@@ -452,6 +454,12 @@ class Checkpointer(CheckpointerBase):
452 unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) 454 unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
453 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 455 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
454 456
457 orig_unet_dtype = unet.dtype
458 orig_text_encoder_dtype = text_encoder.dtype
459
460 unet.to(dtype=self.weight_dtype)
461 text_encoder.to(dtype=self.weight_dtype)
462
455 pipeline = VlpnStableDiffusion( 463 pipeline = VlpnStableDiffusion(
456 text_encoder=text_encoder, 464 text_encoder=text_encoder,
457 vae=self.vae, 465 vae=self.vae,
@@ -463,6 +471,9 @@ class Checkpointer(CheckpointerBase):
463 471
464 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) 472 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
465 473
474 unet.to(dtype=orig_unet_dtype)
475 text_encoder.to(dtype=orig_text_encoder_dtype)
476
466 del unet 477 del unet
467 del text_encoder 478 del text_encoder
468 del pipeline 479 del pipeline
@@ -798,6 +809,7 @@ def main():
798 max_acc_val = 0.0 809 max_acc_val = 0.0
799 810
800 checkpointer = Checkpointer( 811 checkpointer = Checkpointer(
812 weight_dtype=weight_dtype,
801 datamodule=datamodule, 813 datamodule=datamodule,
802 accelerator=accelerator, 814 accelerator=accelerator,
803 vae=vae, 815 vae=vae,