diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-13 13:49:35 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-13 13:49:35 +0100 |
| commit | 7b149930bb53b93db74106ad20a30abf4b114f9b (patch) | |
| tree | 67c2ccbce2a9838ad8a020ee527b19113e67e30a /training/util.py | |
| parent | Added TI decay start offset (diff) | |
| download | textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.gz textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.bz2 textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.zip | |
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'training/util.py')
| -rw-r--r-- | training/util.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/training/util.py b/training/util.py index 60d64f0..0ec2032 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -55,8 +55,19 @@ class CheckpointerBase: | |||
| 55 | self.sample_batches = sample_batches | 55 | self.sample_batches = sample_batches |
| 56 | self.sample_batch_size = sample_batch_size | 56 | self.sample_batch_size = sample_batch_size |
| 57 | 57 | ||
| 58 | @torch.no_grad() | ||
| 59 | def checkpoint(self, step: int, postfix: str): | ||
| 60 | pass | ||
| 61 | |||
| 58 | @torch.inference_mode() | 62 | @torch.inference_mode() |
| 59 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 63 | def save_samples( |
| 64 | self, | ||
| 65 | pipeline, | ||
| 66 | step: int, | ||
| 67 | num_inference_steps: int, | ||
| 68 | guidance_scale: float = 7.5, | ||
| 69 | eta: float = 0.0 | ||
| 70 | ): | ||
| 60 | samples_path = Path(self.output_dir).joinpath("samples") | 71 | samples_path = Path(self.output_dir).joinpath("samples") |
| 61 | 72 | ||
| 62 | train_data = self.datamodule.train_dataloader | 73 | train_data = self.datamodule.train_dataloader |
