summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-13 13:49:35 +0100
committerVolpeon <git@volpeon.ink>2023-01-13 13:49:35 +0100
commit7b149930bb53b93db74106ad20a30abf4b114f9b (patch)
tree67c2ccbce2a9838ad8a020ee527b19113e67e30a /training/util.py
parentAdded TI decay start offset (diff)
downloadtextual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.gz
textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.bz2
textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.zip
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/training/util.py b/training/util.py
index 60d64f0..0ec2032 100644
--- a/training/util.py
+++ b/training/util.py
@@ -55,8 +55,19 @@ class CheckpointerBase:
55 self.sample_batches = sample_batches 55 self.sample_batches = sample_batches
56 self.sample_batch_size = sample_batch_size 56 self.sample_batch_size = sample_batch_size
57 57
58 @torch.no_grad()
59 def checkpoint(self, step: int, postfix: str):
60 pass
61
58 @torch.inference_mode() 62 @torch.inference_mode()
59 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 63 def save_samples(
64 self,
65 pipeline,
66 step: int,
67 num_inference_steps: int,
68 guidance_scale: float = 7.5,
69 eta: float = 0.0
70 ):
60 samples_path = Path(self.output_dir).joinpath("samples") 71 samples_path = Path(self.output_dir).joinpath("samples")
61 72
62 train_data = self.datamodule.train_dataloader 73 train_data = self.datamodule.train_dataloader