summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-25 23:50:24 +0100
committerVolpeon <git@volpeon.ink>2022-12-25 23:50:24 +0100
commit7505f7e843dc719622a15f4ee301609813763d77 (patch)
treefe67640dce9fec4f625d6d1600c696cd7de006ee /train_ti.py
parentUpdate (diff)
downloadtextual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.gz
textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.bz2
textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.zip
Code simplifications, avoid autocast
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index 5f37d54..a228795 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -361,6 +361,7 @@ def parse_args():
361class Checkpointer(CheckpointerBase): 361class Checkpointer(CheckpointerBase):
362 def __init__( 362 def __init__(
363 self, 363 self,
364 weight_dtype,
364 datamodule, 365 datamodule,
365 accelerator, 366 accelerator,
366 vae, 367 vae,
@@ -387,6 +388,7 @@ class Checkpointer(CheckpointerBase):
387 sample_batch_size=sample_batch_size 388 sample_batch_size=sample_batch_size
388 ) 389 )
389 390
391 self.weight_dtype = weight_dtype
390 self.accelerator = accelerator 392 self.accelerator = accelerator
391 self.vae = vae 393 self.vae = vae
392 self.unet = unet 394 self.unet = unet
@@ -417,8 +419,9 @@ class Checkpointer(CheckpointerBase):
417 @torch.no_grad() 419 @torch.no_grad()
418 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 420 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
419 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 421 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
422 orig_dtype = text_encoder.dtype
423 text_encoder.to(dtype=self.weight_dtype)
420 424
421 # Save a sample image
422 pipeline = VlpnStableDiffusion( 425 pipeline = VlpnStableDiffusion(
423 text_encoder=text_encoder, 426 text_encoder=text_encoder,
424 vae=self.vae, 427 vae=self.vae,
@@ -430,6 +433,8 @@ class Checkpointer(CheckpointerBase):
430 433
431 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) 434 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
432 435
436 text_encoder.to(dtype=orig_dtype)
437
433 del text_encoder 438 del text_encoder
434 del pipeline 439 del pipeline
435 440
@@ -739,6 +744,7 @@ def main():
739 max_acc_val = 0.0 744 max_acc_val = 0.0
740 745
741 checkpointer = Checkpointer( 746 checkpointer = Checkpointer(
747 weight_dtype=weight_dtype,
742 datamodule=datamodule, 748 datamodule=datamodule,
743 accelerator=accelerator, 749 accelerator=accelerator,
744 vae=vae, 750 vae=vae,