summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-07 15:05:39 +0100
committerVolpeon <git@volpeon.ink>2023-01-07 15:05:39 +0100
commit6970adaff742ac89adb3d85c803689210dc030e2 (patch)
tree042eec1c77b800c3b64eff4b8cc40f0a7b153e4d /training/util.py
parentAdded progressive aspect ratio bucketing (diff)
downloadtextual-inversion-diff-6970adaff742ac89adb3d85c803689210dc030e2.tar.gz
textual-inversion-diff-6970adaff742ac89adb3d85c803689210dc030e2.tar.bz2
textual-inversion-diff-6970adaff742ac89adb3d85c803689210dc030e2.zip
Made aspect ratio bucketing configurable
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py9
1 files changed, 2 insertions, 7 deletions
diff --git a/training/util.py b/training/util.py
index 6f42228..2b7f71d 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,7 @@
1from pathlib import Path 1from pathlib import Path
2import json 2import json
3import copy 3import copy
4import itertools
4from typing import Iterable, Optional 5from typing import Iterable, Optional
5from contextlib import contextmanager 6from contextlib import contextmanager
6 7
@@ -71,13 +72,7 @@ class CheckpointerBase:
71 file_path = samples_path.joinpath(pool, f"step_{step}.jpg") 72 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
72 file_path.parent.mkdir(parents=True, exist_ok=True) 73 file_path.parent.mkdir(parents=True, exist_ok=True)
73 74
74 data_enum = enumerate(data) 75 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
75
76 batches = [
77 batch
78 for j, batch in data_enum
79 if j * data.batch_size < self.sample_batch_size * self.sample_batches
80 ]
81 prompts = [ 76 prompts = [
82 prompt 77 prompt
83 for batch in batches 78 for batch in batches