diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
commit | e2d3a62bce63fcde940395a1c5618c4eb43385a9 (patch) | |
tree | 574f7a794feab13e1cf0ed18522a66d4737b6db3 /training/util.py | |
parent | Unified training script structure (diff) | |
download | textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.gz textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.bz2 textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.zip |
Cleanup
Diffstat (limited to 'training/util.py')
-rw-r--r-- | training/util.py | 26 |
1 files changed, 11 insertions, 15 deletions
diff --git a/training/util.py b/training/util.py index cc4cdee..1008021 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -44,32 +44,29 @@ class CheckpointerBase: | |||
44 | train_dataloader, | 44 | train_dataloader, |
45 | val_dataloader, | 45 | val_dataloader, |
46 | output_dir: Path, | 46 | output_dir: Path, |
47 | sample_image_size: int, | 47 | sample_steps: int = 20, |
48 | sample_batches: int, | 48 | sample_guidance_scale: float = 7.5, |
49 | sample_batch_size: int, | 49 | sample_image_size: int = 768, |
50 | sample_batches: int = 1, | ||
51 | sample_batch_size: int = 1, | ||
50 | seed: Optional[int] = None | 52 | seed: Optional[int] = None |
51 | ): | 53 | ): |
52 | self.train_dataloader = train_dataloader | 54 | self.train_dataloader = train_dataloader |
53 | self.val_dataloader = val_dataloader | 55 | self.val_dataloader = val_dataloader |
54 | self.output_dir = output_dir | 56 | self.output_dir = output_dir |
55 | self.sample_image_size = sample_image_size | 57 | self.sample_image_size = sample_image_size |
56 | self.seed = seed if seed is not None else torch.random.seed() | 58 | self.sample_steps = sample_steps |
59 | self.sample_guidance_scale = sample_guidance_scale | ||
57 | self.sample_batches = sample_batches | 60 | self.sample_batches = sample_batches |
58 | self.sample_batch_size = sample_batch_size | 61 | self.sample_batch_size = sample_batch_size |
62 | self.seed = seed if seed is not None else torch.random.seed() | ||
59 | 63 | ||
60 | @torch.no_grad() | 64 | @torch.no_grad() |
61 | def checkpoint(self, step: int, postfix: str): | 65 | def checkpoint(self, step: int, postfix: str): |
62 | pass | 66 | pass |
63 | 67 | ||
64 | @torch.inference_mode() | 68 | @torch.inference_mode() |
65 | def save_samples( | 69 | def save_samples(self, pipeline, step: int): |
66 | self, | ||
67 | pipeline, | ||
68 | step: int, | ||
69 | num_inference_steps: int, | ||
70 | guidance_scale: float = 7.5, | ||
71 | eta: float = 0.0 | ||
72 | ): | ||
73 | samples_path = Path(self.output_dir).joinpath("samples") | 70 | samples_path = Path(self.output_dir).joinpath("samples") |
74 | 71 | ||
75 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 72 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
@@ -110,9 +107,8 @@ class CheckpointerBase: | |||
110 | height=self.sample_image_size, | 107 | height=self.sample_image_size, |
111 | width=self.sample_image_size, | 108 | width=self.sample_image_size, |
112 | generator=gen, | 109 | generator=gen, |
113 | guidance_scale=guidance_scale, | 110 | guidance_scale=self.sample_guidance_scale, |
114 | eta=eta, | 111 | num_inference_steps=self.sample_steps, |
115 | num_inference_steps=num_inference_steps, | ||
116 | output_type='pil' | 112 | output_type='pil' |
117 | ).images | 113 | ).images |
118 | 114 | ||