summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py26
1 files changed, 11 insertions, 15 deletions
diff --git a/training/util.py b/training/util.py
index cc4cdee..1008021 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,32 +44,29 @@ class CheckpointerBase:
44 train_dataloader, 44 train_dataloader,
45 val_dataloader, 45 val_dataloader,
46 output_dir: Path, 46 output_dir: Path,
47 sample_image_size: int, 47 sample_steps: int = 20,
48 sample_batches: int, 48 sample_guidance_scale: float = 7.5,
49 sample_batch_size: int, 49 sample_image_size: int = 768,
50 sample_batches: int = 1,
51 sample_batch_size: int = 1,
50 seed: Optional[int] = None 52 seed: Optional[int] = None
51 ): 53 ):
52 self.train_dataloader = train_dataloader 54 self.train_dataloader = train_dataloader
53 self.val_dataloader = val_dataloader 55 self.val_dataloader = val_dataloader
54 self.output_dir = output_dir 56 self.output_dir = output_dir
55 self.sample_image_size = sample_image_size 57 self.sample_image_size = sample_image_size
56 self.seed = seed if seed is not None else torch.random.seed() 58 self.sample_steps = sample_steps
59 self.sample_guidance_scale = sample_guidance_scale
57 self.sample_batches = sample_batches 60 self.sample_batches = sample_batches
58 self.sample_batch_size = sample_batch_size 61 self.sample_batch_size = sample_batch_size
62 self.seed = seed if seed is not None else torch.random.seed()
59 63
60 @torch.no_grad() 64 @torch.no_grad()
61 def checkpoint(self, step: int, postfix: str): 65 def checkpoint(self, step: int, postfix: str):
62 pass 66 pass
63 67
64 @torch.inference_mode() 68 @torch.inference_mode()
65 def save_samples( 69 def save_samples(self, pipeline, step: int):
66 self,
67 pipeline,
68 step: int,
69 num_inference_steps: int,
70 guidance_scale: float = 7.5,
71 eta: float = 0.0
72 ):
73 samples_path = Path(self.output_dir).joinpath("samples") 70 samples_path = Path(self.output_dir).joinpath("samples")
74 71
75 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) 72 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
@@ -110,9 +107,8 @@ class CheckpointerBase:
110 height=self.sample_image_size, 107 height=self.sample_image_size,
111 width=self.sample_image_size, 108 width=self.sample_image_size,
112 generator=gen, 109 generator=gen,
113 guidance_scale=guidance_scale, 110 guidance_scale=self.sample_guidance_scale,
114 eta=eta, 111 num_inference_steps=self.sample_steps,
115 num_inference_steps=num_inference_steps,
116 output_type='pil' 112 output_type='pil'
117 ).images 113 ).images
118 114