diff options
Diffstat (limited to 'training/common.py')
| -rw-r--r-- | training/common.py | 85 |
1 file changed, 46 insertions, 39 deletions
diff --git a/training/common.py b/training/common.py index b6964a3..f5ab326 100644 --- a/training/common.py +++ b/training/common.py | |||
| @@ -45,42 +45,44 @@ def generate_class_images( | |||
| 45 | ): | 45 | ): |
| 46 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | 46 | missing_data = [item for item in data_train if not item.class_image_path.exists()] |
| 47 | 47 | ||
| 48 | if len(missing_data) != 0: | 48 | if len(missing_data) == 0: |
| 49 | batched_data = [ | 49 | return |
| 50 | missing_data[i:i+sample_batch_size] | ||
| 51 | for i in range(0, len(missing_data), sample_batch_size) | ||
| 52 | ] | ||
| 53 | 50 | ||
| 54 | pipeline = VlpnStableDiffusion( | 51 | batched_data = [ |
| 55 | text_encoder=text_encoder, | 52 | missing_data[i:i+sample_batch_size] |
| 56 | vae=vae, | 53 | for i in range(0, len(missing_data), sample_batch_size) |
| 57 | unet=unet, | 54 | ] |
| 58 | tokenizer=tokenizer, | 55 | |
| 59 | scheduler=scheduler, | 56 | pipeline = VlpnStableDiffusion( |
| 60 | ).to(accelerator.device) | 57 | text_encoder=text_encoder, |
| 61 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 58 | vae=vae, |
| 59 | unet=unet, | ||
| 60 | tokenizer=tokenizer, | ||
| 61 | scheduler=scheduler, | ||
| 62 | ).to(accelerator.device) | ||
| 63 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
| 62 | 64 | ||
| 63 | with torch.inference_mode(): | 65 | with torch.inference_mode(): |
| 64 | for batch in batched_data: | 66 | for batch in batched_data: |
| 65 | image_name = [item.class_image_path for item in batch] | 67 | image_name = [item.class_image_path for item in batch] |
| 66 | prompt = [item.cprompt for item in batch] | 68 | prompt = [item.cprompt for item in batch] |
| 67 | nprompt = [item.nprompt for item in batch] | 69 | nprompt = [item.nprompt for item in batch] |
| 68 | 70 | ||
| 69 | images = pipeline( | 71 | images = pipeline( |
| 70 | prompt=prompt, | 72 | prompt=prompt, |
| 71 | negative_prompt=nprompt, | 73 | negative_prompt=nprompt, |
| 72 | height=sample_image_size, | 74 | height=sample_image_size, |
| 73 | width=sample_image_size, | 75 | width=sample_image_size, |
| 74 | num_inference_steps=sample_steps | 76 | num_inference_steps=sample_steps |
| 75 | ).images | 77 | ).images |
| 76 | 78 | ||
| 77 | for i, image in enumerate(images): | 79 | for i, image in enumerate(images): |
| 78 | image.save(image_name[i]) | 80 | image.save(image_name[i]) |
| 79 | 81 | ||
| 80 | del pipeline | 82 | del pipeline |
| 81 | 83 | ||
| 82 | if torch.cuda.is_available(): | 84 | if torch.cuda.is_available(): |
| 83 | torch.cuda.empty_cache() | 85 | torch.cuda.empty_cache() |
| 84 | 86 | ||
| 85 | 87 | ||
| 86 | def get_models(pretrained_model_name_or_path: str): | 88 | def get_models(pretrained_model_name_or_path: str): |
| @@ -119,7 +121,7 @@ def add_placeholder_tokens( | |||
| 119 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): | 121 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): |
| 120 | embeddings.add_embed(placeholder_token_id, initializer_token_id) | 122 | embeddings.add_embed(placeholder_token_id, initializer_token_id) |
| 121 | 123 | ||
| 122 | return placeholder_token_ids | 124 | return placeholder_token_ids, initializer_token_ids |
| 123 | 125 | ||
| 124 | 126 | ||
| 125 | def loss_step( | 127 | def loss_step( |
| @@ -127,7 +129,6 @@ def loss_step( | |||
| 127 | noise_scheduler: DDPMScheduler, | 129 | noise_scheduler: DDPMScheduler, |
| 128 | unet: UNet2DConditionModel, | 130 | unet: UNet2DConditionModel, |
| 129 | text_encoder: CLIPTextModel, | 131 | text_encoder: CLIPTextModel, |
| 130 | with_prior: bool, | ||
| 131 | prior_loss_weight: float, | 132 | prior_loss_weight: float, |
| 132 | seed: int, | 133 | seed: int, |
| 133 | step: int, | 134 | step: int, |
| @@ -138,16 +139,23 @@ def loss_step( | |||
| 138 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 139 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
| 139 | latents = latents * 0.18215 | 140 | latents = latents * 0.18215 |
| 140 | 141 | ||
| 142 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | ||
| 143 | |||
| 141 | # Sample noise that we'll add to the latents | 144 | # Sample noise that we'll add to the latents |
| 142 | noise = torch.randn_like(latents) | 145 | noise = torch.randn( |
| 146 | latents.shape, | ||
| 147 | dtype=latents.dtype, | ||
| 148 | layout=latents.layout, | ||
| 149 | device=latents.device, | ||
| 150 | generator=generator | ||
| 151 | ) | ||
| 143 | bsz = latents.shape[0] | 152 | bsz = latents.shape[0] |
| 144 | # Sample a random timestep for each image | 153 | # Sample a random timestep for each image |
| 145 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | ||
| 146 | timesteps = torch.randint( | 154 | timesteps = torch.randint( |
| 147 | 0, | 155 | 0, |
| 148 | noise_scheduler.config.num_train_timesteps, | 156 | noise_scheduler.config.num_train_timesteps, |
| 149 | (bsz,), | 157 | (bsz,), |
| 150 | generator=timesteps_gen, | 158 | generator=generator, |
| 151 | device=latents.device, | 159 | device=latents.device, |
| 152 | ) | 160 | ) |
| 153 | timesteps = timesteps.long() | 161 | timesteps = timesteps.long() |
| @@ -176,7 +184,7 @@ def loss_step( | |||
| 176 | else: | 184 | else: |
| 177 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 185 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
| 178 | 186 | ||
| 179 | if with_prior: | 187 | if batch["with_prior"]: |
| 180 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 188 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 181 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 189 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| 182 | target, target_prior = torch.chunk(target, 2, dim=0) | 190 | target, target_prior = torch.chunk(target, 2, dim=0) |
| @@ -207,7 +215,6 @@ def train_loop( | |||
| 207 | val_dataloader: DataLoader, | 215 | val_dataloader: DataLoader, |
| 208 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 216 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
| 209 | sample_frequency: int = 10, | 217 | sample_frequency: int = 10, |
| 210 | sample_steps: int = 20, | ||
| 211 | checkpoint_frequency: int = 50, | 218 | checkpoint_frequency: int = 50, |
| 212 | global_step_offset: int = 0, | 219 | global_step_offset: int = 0, |
| 213 | num_epochs: int = 100, | 220 | num_epochs: int = 100, |
| @@ -251,7 +258,7 @@ def train_loop( | |||
| 251 | for epoch in range(num_epochs): | 258 | for epoch in range(num_epochs): |
| 252 | if accelerator.is_main_process: | 259 | if accelerator.is_main_process: |
| 253 | if epoch % sample_frequency == 0: | 260 | if epoch % sample_frequency == 0: |
| 254 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | 261 | checkpointer.save_samples(global_step + global_step_offset) |
| 255 | 262 | ||
| 256 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 263 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
| 257 | checkpointer.checkpoint(global_step + global_step_offset, "training") | 264 | checkpointer.checkpoint(global_step + global_step_offset, "training") |
| @@ -353,7 +360,7 @@ def train_loop( | |||
| 353 | if accelerator.is_main_process: | 360 | if accelerator.is_main_process: |
| 354 | print("Finished!") | 361 | print("Finished!") |
| 355 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 362 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
| 356 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | 363 | checkpointer.save_samples(global_step + global_step_offset) |
| 357 | accelerator.end_training() | 364 | accelerator.end_training() |
| 358 | 365 | ||
| 359 | except KeyboardInterrupt: | 366 | except KeyboardInterrupt: |
