diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 | 
| commit | 7505f7e843dc719622a15f4ee301609813763d77 (patch) | |
| tree | fe67640dce9fec4f625d6d1600c696cd7de006ee /train_ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.gz textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.bz2 textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.zip | |
Code simplifications, avoid autocast
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 8 | 
1 files changed, 7 insertions, 1 deletions
| diff --git a/train_ti.py b/train_ti.py index 5f37d54..a228795 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -361,6 +361,7 @@ def parse_args(): | |||
| 361 | class Checkpointer(CheckpointerBase): | 361 | class Checkpointer(CheckpointerBase): | 
| 362 | def __init__( | 362 | def __init__( | 
| 363 | self, | 363 | self, | 
| 364 | weight_dtype, | ||
| 364 | datamodule, | 365 | datamodule, | 
| 365 | accelerator, | 366 | accelerator, | 
| 366 | vae, | 367 | vae, | 
| @@ -387,6 +388,7 @@ class Checkpointer(CheckpointerBase): | |||
| 387 | sample_batch_size=sample_batch_size | 388 | sample_batch_size=sample_batch_size | 
| 388 | ) | 389 | ) | 
| 389 | 390 | ||
| 391 | self.weight_dtype = weight_dtype | ||
| 390 | self.accelerator = accelerator | 392 | self.accelerator = accelerator | 
| 391 | self.vae = vae | 393 | self.vae = vae | 
| 392 | self.unet = unet | 394 | self.unet = unet | 
| @@ -417,8 +419,9 @@ class Checkpointer(CheckpointerBase): | |||
| 417 | @torch.no_grad() | 419 | @torch.no_grad() | 
| 418 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 420 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 
| 419 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 421 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 
| 422 | orig_dtype = text_encoder.dtype | ||
| 423 | text_encoder.to(dtype=self.weight_dtype) | ||
| 420 | 424 | ||
| 421 | # Save a sample image | ||
| 422 | pipeline = VlpnStableDiffusion( | 425 | pipeline = VlpnStableDiffusion( | 
| 423 | text_encoder=text_encoder, | 426 | text_encoder=text_encoder, | 
| 424 | vae=self.vae, | 427 | vae=self.vae, | 
| @@ -430,6 +433,8 @@ class Checkpointer(CheckpointerBase): | |||
| 430 | 433 | ||
| 431 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 434 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 
| 432 | 435 | ||
| 436 | text_encoder.to(dtype=orig_dtype) | ||
| 437 | |||
| 433 | del text_encoder | 438 | del text_encoder | 
| 434 | del pipeline | 439 | del pipeline | 
| 435 | 440 | ||
| @@ -739,6 +744,7 @@ def main(): | |||
| 739 | max_acc_val = 0.0 | 744 | max_acc_val = 0.0 | 
| 740 | 745 | ||
| 741 | checkpointer = Checkpointer( | 746 | checkpointer = Checkpointer( | 
| 747 | weight_dtype=weight_dtype, | ||
| 742 | datamodule=datamodule, | 748 | datamodule=datamodule, | 
| 743 | accelerator=accelerator, | 749 | accelerator=accelerator, | 
| 744 | vae=vae, | 750 | vae=vae, | 
