diff options
author | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 |
commit | 7505f7e843dc719622a15f4ee301609813763d77 (patch) | |
tree | fe67640dce9fec4f625d6d1600c696cd7de006ee /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.gz textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.bz2 textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.zip |
Code simplifications, avoid autocast
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index 5f37d54..a228795 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -361,6 +361,7 @@ def parse_args(): | |||
361 | class Checkpointer(CheckpointerBase): | 361 | class Checkpointer(CheckpointerBase): |
362 | def __init__( | 362 | def __init__( |
363 | self, | 363 | self, |
364 | weight_dtype, | ||
364 | datamodule, | 365 | datamodule, |
365 | accelerator, | 366 | accelerator, |
366 | vae, | 367 | vae, |
@@ -387,6 +388,7 @@ class Checkpointer(CheckpointerBase): | |||
387 | sample_batch_size=sample_batch_size | 388 | sample_batch_size=sample_batch_size |
388 | ) | 389 | ) |
389 | 390 | ||
391 | self.weight_dtype = weight_dtype | ||
390 | self.accelerator = accelerator | 392 | self.accelerator = accelerator |
391 | self.vae = vae | 393 | self.vae = vae |
392 | self.unet = unet | 394 | self.unet = unet |
@@ -417,8 +419,9 @@ class Checkpointer(CheckpointerBase): | |||
417 | @torch.no_grad() | 419 | @torch.no_grad() |
418 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 420 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
419 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 421 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
422 | orig_dtype = text_encoder.dtype | ||
423 | text_encoder.to(dtype=self.weight_dtype) | ||
420 | 424 | ||
421 | # Save a sample image | ||
422 | pipeline = VlpnStableDiffusion( | 425 | pipeline = VlpnStableDiffusion( |
423 | text_encoder=text_encoder, | 426 | text_encoder=text_encoder, |
424 | vae=self.vae, | 427 | vae=self.vae, |
@@ -430,6 +433,8 @@ class Checkpointer(CheckpointerBase): | |||
430 | 433 | ||
431 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 434 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) |
432 | 435 | ||
436 | text_encoder.to(dtype=orig_dtype) | ||
437 | |||
433 | del text_encoder | 438 | del text_encoder |
434 | del pipeline | 439 | del pipeline |
435 | 440 | ||
@@ -739,6 +744,7 @@ def main(): | |||
739 | max_acc_val = 0.0 | 744 | max_acc_val = 0.0 |
740 | 745 | ||
741 | checkpointer = Checkpointer( | 746 | checkpointer = Checkpointer( |
747 | weight_dtype=weight_dtype, | ||
742 | datamodule=datamodule, | 748 | datamodule=datamodule, |
743 | accelerator=accelerator, | 749 | accelerator=accelerator, |
744 | vae=vae, | 750 | vae=vae, |