summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-07 13:57:46 +0100
committerVolpeon <git@volpeon.ink>2023-01-07 13:57:46 +0100
commit3ee13893f9a4973ac75f45fe9318c35760dd4b1f (patch)
treee652a54e6c241eef52ddb30f2d7048da8f306f7b /training/util.py
parentUpdate (diff)
downloadtextual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.gz
textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.bz2
textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.zip
Added progressive aspect ratio bucketing
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/training/util.py b/training/util.py
index bc466e2..6f42228 100644
--- a/training/util.py
+++ b/training/util.py
@@ -58,8 +58,8 @@ class CheckpointerBase:
58 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 58 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
59 samples_path = Path(self.output_dir).joinpath("samples") 59 samples_path = Path(self.output_dir).joinpath("samples")
60 60
61 train_data = self.datamodule.train_dataloader() 61 train_data = self.datamodule.train_dataloaders[0]
62 val_data = self.datamodule.val_dataloader() 62 val_data = self.datamodule.val_dataloader
63 63
64 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) 64 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
65 65