summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-13 18:59:26 +0100
committerVolpeon <git@volpeon.ink>2023-01-13 18:59:26 +0100
commit127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch)
tree61cb98adbf33ed08506601f8b70f1b62bc42c4ee /training/util.py
parentSimplified step calculations (diff)
downloadtextual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.gz
textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.bz2
textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.zip
More modularization
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/training/util.py b/training/util.py
index 0ec2032..cc4cdee 100644
--- a/training/util.py
+++ b/training/util.py
@@ -41,14 +41,16 @@ class AverageMeter:
41class CheckpointerBase: 41class CheckpointerBase:
42 def __init__( 42 def __init__(
43 self, 43 self,
44 datamodule, 44 train_dataloader,
45 val_dataloader,
45 output_dir: Path, 46 output_dir: Path,
46 sample_image_size: int, 47 sample_image_size: int,
47 sample_batches: int, 48 sample_batches: int,
48 sample_batch_size: int, 49 sample_batch_size: int,
49 seed: Optional[int] = None 50 seed: Optional[int] = None
50 ): 51 ):
51 self.datamodule = datamodule 52 self.train_dataloader = train_dataloader
53 self.val_dataloader = val_dataloader
52 self.output_dir = output_dir 54 self.output_dir = output_dir
53 self.sample_image_size = sample_image_size 55 self.sample_image_size = sample_image_size
54 self.seed = seed if seed is not None else torch.random.seed() 56 self.seed = seed if seed is not None else torch.random.seed()
@@ -70,15 +72,16 @@ class CheckpointerBase:
70 ): 72 ):
71 samples_path = Path(self.output_dir).joinpath("samples") 73 samples_path = Path(self.output_dir).joinpath("samples")
72 74
73 train_data = self.datamodule.train_dataloader
74 val_data = self.datamodule.val_dataloader
75
76 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) 75 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
77 76
78 grid_cols = min(self.sample_batch_size, 4) 77 grid_cols = min(self.sample_batch_size, 4)
79 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols 78 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
80 79
81 for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]: 80 for pool, data, gen in [
81 ("stable", self.val_dataloader, generator),
82 ("val", self.val_dataloader, None),
83 ("train", self.train_dataloader, None)
84 ]:
82 all_samples = [] 85 all_samples = []
83 file_path = samples_path.joinpath(pool, f"step_{step}.jpg") 86 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
84 file_path.parent.mkdir(parents=True, exist_ok=True) 87 file_path.parent.mkdir(parents=True, exist_ok=True)