diff options
author | Volpeon <git@volpeon.ink> | 2023-01-13 18:59:26 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-13 18:59:26 +0100 |
commit | 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch) | |
tree | 61cb98adbf33ed08506601f8b70f1b62bc42c4ee /training/util.py | |
parent | Simplified step calculations (diff) | |
download | textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.gz textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.bz2 textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.zip |
More modularization
Diffstat (limited to 'training/util.py')
-rw-r--r-- | training/util.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/training/util.py b/training/util.py index 0ec2032..cc4cdee 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -41,14 +41,16 @@ class AverageMeter: | |||
41 | class CheckpointerBase: | 41 | class CheckpointerBase: |
42 | def __init__( | 42 | def __init__( |
43 | self, | 43 | self, |
44 | datamodule, | 44 | train_dataloader, |
45 | val_dataloader, | ||
45 | output_dir: Path, | 46 | output_dir: Path, |
46 | sample_image_size: int, | 47 | sample_image_size: int, |
47 | sample_batches: int, | 48 | sample_batches: int, |
48 | sample_batch_size: int, | 49 | sample_batch_size: int, |
49 | seed: Optional[int] = None | 50 | seed: Optional[int] = None |
50 | ): | 51 | ): |
51 | self.datamodule = datamodule | 52 | self.train_dataloader = train_dataloader |
53 | self.val_dataloader = val_dataloader | ||
52 | self.output_dir = output_dir | 54 | self.output_dir = output_dir |
53 | self.sample_image_size = sample_image_size | 55 | self.sample_image_size = sample_image_size |
54 | self.seed = seed if seed is not None else torch.random.seed() | 56 | self.seed = seed if seed is not None else torch.random.seed() |
@@ -70,15 +72,16 @@ class CheckpointerBase: | |||
70 | ): | 72 | ): |
71 | samples_path = Path(self.output_dir).joinpath("samples") | 73 | samples_path = Path(self.output_dir).joinpath("samples") |
72 | 74 | ||
73 | train_data = self.datamodule.train_dataloader | ||
74 | val_data = self.datamodule.val_dataloader | ||
75 | |||
76 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 75 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
77 | 76 | ||
78 | grid_cols = min(self.sample_batch_size, 4) | 77 | grid_cols = min(self.sample_batch_size, 4) |
79 | grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols | 78 | grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols |
80 | 79 | ||
81 | for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]: | 80 | for pool, data, gen in [ |
81 | ("stable", self.val_dataloader, generator), | ||
82 | ("val", self.val_dataloader, None), | ||
83 | ("train", self.train_dataloader, None) | ||
84 | ]: | ||
82 | all_samples = [] | 85 | all_samples = [] |
83 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 86 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
84 | file_path.parent.mkdir(parents=True, exist_ok=True) | 87 | file_path.parent.mkdir(parents=True, exist_ok=True) |