Diffstat (limited to 'training')

 -rw-r--r--   training/common.py | 97
 -rw-r--r--   training/util.py   | 26

 2 files changed, 63 insertions, 60 deletions
diff --git a/training/common.py b/training/common.py
index b6964a3..f5ab326 100644
--- a/training/common.py
+++ b/training/common.py
@@ -45,42 +45,44 @@ def generate_class_images(
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]

-    if len(missing_data) != 0:
-        batched_data = [
-            missing_data[i:i+sample_batch_size]
-            for i in range(0, len(missing_data), sample_batch_size)
-        ]
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=scheduler,
-        ).to(accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        with torch.inference_mode():
-            for batch in batched_data:
-                image_name = [item.class_image_path for item in batch]
-                prompt = [item.cprompt for item in batch]
-                nprompt = [item.nprompt for item in batch]
-
-                images = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=sample_image_size,
-                    width=sample_image_size,
-                    num_inference_steps=sample_steps
-                ).images
-
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
-
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()


 def get_models(pretrained_model_name_or_path: str):
@@ -119,7 +121,7 @@ def add_placeholder_tokens(
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
         embeddings.add_embed(placeholder_token_id, initializer_token_id)

-    return placeholder_token_ids
+    return placeholder_token_ids, initializer_token_ids


 def loss_step(
@@ -127,7 +129,6 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -138,16 +139,23 @@ def loss_step(
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
     latents = latents * 0.18215

+    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+
     # Sample noise that we'll add to the latents
-    noise = torch.randn_like(latents)
+    noise = torch.randn(
+        latents.shape,
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+        generator=generator
+    )
     bsz = latents.shape[0]
     # Sample a random timestep for each image
-    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
     timesteps = torch.randint(
         0,
         noise_scheduler.config.num_train_timesteps,
         (bsz,),
-        generator=timesteps_gen,
+        generator=generator,
         device=latents.device,
     )
     timesteps = timesteps.long()
@@ -176,7 +184,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

-    if with_prior:
+    if batch["with_prior"]:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -207,7 +215,6 @@ def train_loop(
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
-    sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
@@ -251,7 +258,7 @@ def train_loop(
     for epoch in range(num_epochs):
         if accelerator.is_main_process:
             if epoch % sample_frequency == 0:
-                checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+                checkpointer.save_samples(global_step + global_step_offset)

             if epoch % checkpoint_frequency == 0 and epoch != 0:
                 checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -353,7 +360,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Finished!")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+            checkpointer.save_samples(global_step + global_step_offset)
             accelerator.end_training()

     except KeyboardInterrupt:
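Note on the loss_step change above: replacing torch.randn_like with an explicit torch.randn call lets the noise and the timesteps share a single seeded torch.Generator when eval is set, so validation batches become reproducible. A minimal standalone sketch of that pattern (the helper name, placeholder shapes, and the example values are illustrative, not part of this commit):

import torch

def sample_noise_and_timesteps(latents: torch.Tensor, num_train_timesteps: int,
                               seed: int, step: int, eval: bool):
    # One generator drives both random draws; None falls back to global RNG during training.
    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
    timesteps = torch.randint(
        0,
        num_train_timesteps,
        (latents.shape[0],),
        generator=generator,
        device=latents.device,
    ).long()
    return noise, timesteps

# Two eval calls with the same seed and step yield identical noise and timesteps.
latents = torch.zeros(2, 4, 8, 8)
n1, t1 = sample_noise_and_timesteps(latents, 1000, seed=42, step=3, eval=True)
n2, t2 = sample_noise_and_timesteps(latents, 1000, seed=42, step=3, eval=True)
assert torch.equal(n1, n2) and torch.equal(t1, t2)
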
diff --git a/training/util.py b/training/util.py
index cc4cdee..1008021 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,32 +44,29 @@ class CheckpointerBase:
         train_dataloader,
         val_dataloader,
         output_dir: Path,
-        sample_image_size: int,
-        sample_batches: int,
-        sample_batch_size: int,
+        sample_steps: int = 20,
+        sample_guidance_scale: float = 7.5,
+        sample_image_size: int = 768,
+        sample_batches: int = 1,
+        sample_batch_size: int = 1,
         seed: Optional[int] = None
     ):
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
-        self.seed = seed if seed is not None else torch.random.seed()
+        self.sample_steps = sample_steps
+        self.sample_guidance_scale = sample_guidance_scale
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
+        self.seed = seed if seed is not None else torch.random.seed()

     @torch.no_grad()
     def checkpoint(self, step: int, postfix: str):
         pass

     @torch.inference_mode()
-    def save_samples(
-        self,
-        pipeline,
-        step: int,
-        num_inference_steps: int,
-        guidance_scale: float = 7.5,
-        eta: float = 0.0
-    ):
+    def save_samples(self, pipeline, step: int):
         samples_path = Path(self.output_dir).joinpath("samples")

         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
@@ -110,9 +107,8 @@ class CheckpointerBase:
                 height=self.sample_image_size,
                 width=self.sample_image_size,
                 generator=gen,
-                guidance_scale=guidance_scale,
-                eta=eta,
-                num_inference_steps=num_inference_steps,
+                guidance_scale=self.sample_guidance_scale,
+                num_inference_steps=self.sample_steps,
                 output_type='pil'
             ).images

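Note on the util.py change: the sampling hyperparameters (steps, guidance scale, image size, batch counts) move from save_samples() into the CheckpointerBase constructor, so the per-call signature shrinks to (pipeline, step) and train_loop no longer forwards sample_steps. A rough usage sketch under that reading; MyCheckpointer, the dataloaders, and the step variables are placeholders, not names introduced by this commit:

checkpointer = MyCheckpointer(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    output_dir=Path("output/run-1"),
    sample_steps=20,              # was a save_samples() argument
    sample_guidance_scale=7.5,    # new constructor knob; eta is no longer forwarded
    sample_image_size=768,
    sample_batches=1,
    sample_batch_size=1,
    seed=42,
)

# As in train_loop above, sampling is now triggered with only the global step.
checkpointer.save_samples(global_step + global_step_offset)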