diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
| commit | e2d3a62bce63fcde940395a1c5618c4eb43385a9 (patch) | |
| tree | 574f7a794feab13e1cf0ed18522a66d4737b6db3 /training/util.py | |
| parent | Unified training script structure (diff) | |
| download | textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.gz textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.bz2 textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.zip | |
Cleanup
Diffstat (limited to 'training/util.py')
| -rw-r--r-- | training/util.py | 26 |
1 files changed, 11 insertions, 15 deletions
diff --git a/training/util.py b/training/util.py index cc4cdee..1008021 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -44,32 +44,29 @@ class CheckpointerBase: | |||
| 44 | train_dataloader, | 44 | train_dataloader, |
| 45 | val_dataloader, | 45 | val_dataloader, |
| 46 | output_dir: Path, | 46 | output_dir: Path, |
| 47 | sample_image_size: int, | 47 | sample_steps: int = 20, |
| 48 | sample_batches: int, | 48 | sample_guidance_scale: float = 7.5, |
| 49 | sample_batch_size: int, | 49 | sample_image_size: int = 768, |
| 50 | sample_batches: int = 1, | ||
| 51 | sample_batch_size: int = 1, | ||
| 50 | seed: Optional[int] = None | 52 | seed: Optional[int] = None |
| 51 | ): | 53 | ): |
| 52 | self.train_dataloader = train_dataloader | 54 | self.train_dataloader = train_dataloader |
| 53 | self.val_dataloader = val_dataloader | 55 | self.val_dataloader = val_dataloader |
| 54 | self.output_dir = output_dir | 56 | self.output_dir = output_dir |
| 55 | self.sample_image_size = sample_image_size | 57 | self.sample_image_size = sample_image_size |
| 56 | self.seed = seed if seed is not None else torch.random.seed() | 58 | self.sample_steps = sample_steps |
| 59 | self.sample_guidance_scale = sample_guidance_scale | ||
| 57 | self.sample_batches = sample_batches | 60 | self.sample_batches = sample_batches |
| 58 | self.sample_batch_size = sample_batch_size | 61 | self.sample_batch_size = sample_batch_size |
| 62 | self.seed = seed if seed is not None else torch.random.seed() | ||
| 59 | 63 | ||
| 60 | @torch.no_grad() | 64 | @torch.no_grad() |
| 61 | def checkpoint(self, step: int, postfix: str): | 65 | def checkpoint(self, step: int, postfix: str): |
| 62 | pass | 66 | pass |
| 63 | 67 | ||
| 64 | @torch.inference_mode() | 68 | @torch.inference_mode() |
| 65 | def save_samples( | 69 | def save_samples(self, pipeline, step: int): |
| 66 | self, | ||
| 67 | pipeline, | ||
| 68 | step: int, | ||
| 69 | num_inference_steps: int, | ||
| 70 | guidance_scale: float = 7.5, | ||
| 71 | eta: float = 0.0 | ||
| 72 | ): | ||
| 73 | samples_path = Path(self.output_dir).joinpath("samples") | 70 | samples_path = Path(self.output_dir).joinpath("samples") |
| 74 | 71 | ||
| 75 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 72 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
| @@ -110,9 +107,8 @@ class CheckpointerBase: | |||
| 110 | height=self.sample_image_size, | 107 | height=self.sample_image_size, |
| 111 | width=self.sample_image_size, | 108 | width=self.sample_image_size, |
| 112 | generator=gen, | 109 | generator=gen, |
| 113 | guidance_scale=guidance_scale, | 110 | guidance_scale=self.sample_guidance_scale, |
| 114 | eta=eta, | 111 | num_inference_steps=self.sample_steps, |
| 115 | num_inference_steps=num_inference_steps, | ||
| 116 | output_type='pil' | 112 | output_type='pil' |
| 117 | ).images | 113 | ).images |
| 118 | 114 | ||
