diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/common.py | 85 | ||||
| -rw-r--r-- | training/util.py | 26 |
2 files changed, 57 insertions, 54 deletions
diff --git a/training/common.py b/training/common.py index b6964a3..f5ab326 100644 --- a/training/common.py +++ b/training/common.py | |||
| @@ -45,42 +45,44 @@ def generate_class_images( | |||
| 45 | ): | 45 | ): |
| 46 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | 46 | missing_data = [item for item in data_train if not item.class_image_path.exists()] |
| 47 | 47 | ||
| 48 | if len(missing_data) != 0: | 48 | if len(missing_data) == 0: |
| 49 | batched_data = [ | 49 | return |
| 50 | missing_data[i:i+sample_batch_size] | ||
| 51 | for i in range(0, len(missing_data), sample_batch_size) | ||
| 52 | ] | ||
| 53 | 50 | ||
| 54 | pipeline = VlpnStableDiffusion( | 51 | batched_data = [ |
| 55 | text_encoder=text_encoder, | 52 | missing_data[i:i+sample_batch_size] |
| 56 | vae=vae, | 53 | for i in range(0, len(missing_data), sample_batch_size) |
| 57 | unet=unet, | 54 | ] |
| 58 | tokenizer=tokenizer, | 55 | |
| 59 | scheduler=scheduler, | 56 | pipeline = VlpnStableDiffusion( |
| 60 | ).to(accelerator.device) | 57 | text_encoder=text_encoder, |
| 61 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 58 | vae=vae, |
| 59 | unet=unet, | ||
| 60 | tokenizer=tokenizer, | ||
| 61 | scheduler=scheduler, | ||
| 62 | ).to(accelerator.device) | ||
| 63 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
| 62 | 64 | ||
| 63 | with torch.inference_mode(): | 65 | with torch.inference_mode(): |
| 64 | for batch in batched_data: | 66 | for batch in batched_data: |
| 65 | image_name = [item.class_image_path for item in batch] | 67 | image_name = [item.class_image_path for item in batch] |
| 66 | prompt = [item.cprompt for item in batch] | 68 | prompt = [item.cprompt for item in batch] |
| 67 | nprompt = [item.nprompt for item in batch] | 69 | nprompt = [item.nprompt for item in batch] |
| 68 | 70 | ||
| 69 | images = pipeline( | 71 | images = pipeline( |
| 70 | prompt=prompt, | 72 | prompt=prompt, |
| 71 | negative_prompt=nprompt, | 73 | negative_prompt=nprompt, |
| 72 | height=sample_image_size, | 74 | height=sample_image_size, |
| 73 | width=sample_image_size, | 75 | width=sample_image_size, |
| 74 | num_inference_steps=sample_steps | 76 | num_inference_steps=sample_steps |
| 75 | ).images | 77 | ).images |
| 76 | 78 | ||
| 77 | for i, image in enumerate(images): | 79 | for i, image in enumerate(images): |
| 78 | image.save(image_name[i]) | 80 | image.save(image_name[i]) |
| 79 | 81 | ||
| 80 | del pipeline | 82 | del pipeline |
| 81 | 83 | ||
| 82 | if torch.cuda.is_available(): | 84 | if torch.cuda.is_available(): |
| 83 | torch.cuda.empty_cache() | 85 | torch.cuda.empty_cache() |
| 84 | 86 | ||
| 85 | 87 | ||
| 86 | def get_models(pretrained_model_name_or_path: str): | 88 | def get_models(pretrained_model_name_or_path: str): |
| @@ -119,7 +121,7 @@ def add_placeholder_tokens( | |||
| 119 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): | 121 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): |
| 120 | embeddings.add_embed(placeholder_token_id, initializer_token_id) | 122 | embeddings.add_embed(placeholder_token_id, initializer_token_id) |
| 121 | 123 | ||
| 122 | return placeholder_token_ids | 124 | return placeholder_token_ids, initializer_token_ids |
| 123 | 125 | ||
| 124 | 126 | ||
| 125 | def loss_step( | 127 | def loss_step( |
| @@ -127,7 +129,6 @@ def loss_step( | |||
| 127 | noise_scheduler: DDPMScheduler, | 129 | noise_scheduler: DDPMScheduler, |
| 128 | unet: UNet2DConditionModel, | 130 | unet: UNet2DConditionModel, |
| 129 | text_encoder: CLIPTextModel, | 131 | text_encoder: CLIPTextModel, |
| 130 | with_prior: bool, | ||
| 131 | prior_loss_weight: float, | 132 | prior_loss_weight: float, |
| 132 | seed: int, | 133 | seed: int, |
| 133 | step: int, | 134 | step: int, |
| @@ -138,16 +139,23 @@ def loss_step( | |||
| 138 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 139 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
| 139 | latents = latents * 0.18215 | 140 | latents = latents * 0.18215 |
| 140 | 141 | ||
| 142 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | ||
| 143 | |||
| 141 | # Sample noise that we'll add to the latents | 144 | # Sample noise that we'll add to the latents |
| 142 | noise = torch.randn_like(latents) | 145 | noise = torch.randn( |
| 146 | latents.shape, | ||
| 147 | dtype=latents.dtype, | ||
| 148 | layout=latents.layout, | ||
| 149 | device=latents.device, | ||
| 150 | generator=generator | ||
| 151 | ) | ||
| 143 | bsz = latents.shape[0] | 152 | bsz = latents.shape[0] |
| 144 | # Sample a random timestep for each image | 153 | # Sample a random timestep for each image |
| 145 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | ||
| 146 | timesteps = torch.randint( | 154 | timesteps = torch.randint( |
| 147 | 0, | 155 | 0, |
| 148 | noise_scheduler.config.num_train_timesteps, | 156 | noise_scheduler.config.num_train_timesteps, |
| 149 | (bsz,), | 157 | (bsz,), |
| 150 | generator=timesteps_gen, | 158 | generator=generator, |
| 151 | device=latents.device, | 159 | device=latents.device, |
| 152 | ) | 160 | ) |
| 153 | timesteps = timesteps.long() | 161 | timesteps = timesteps.long() |
| @@ -176,7 +184,7 @@ def loss_step( | |||
| 176 | else: | 184 | else: |
| 177 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 185 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
| 178 | 186 | ||
| 179 | if with_prior: | 187 | if batch["with_prior"]: |
| 180 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 188 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 181 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 189 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| 182 | target, target_prior = torch.chunk(target, 2, dim=0) | 190 | target, target_prior = torch.chunk(target, 2, dim=0) |
| @@ -207,7 +215,6 @@ def train_loop( | |||
| 207 | val_dataloader: DataLoader, | 215 | val_dataloader: DataLoader, |
| 208 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 216 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
| 209 | sample_frequency: int = 10, | 217 | sample_frequency: int = 10, |
| 210 | sample_steps: int = 20, | ||
| 211 | checkpoint_frequency: int = 50, | 218 | checkpoint_frequency: int = 50, |
| 212 | global_step_offset: int = 0, | 219 | global_step_offset: int = 0, |
| 213 | num_epochs: int = 100, | 220 | num_epochs: int = 100, |
| @@ -251,7 +258,7 @@ def train_loop( | |||
| 251 | for epoch in range(num_epochs): | 258 | for epoch in range(num_epochs): |
| 252 | if accelerator.is_main_process: | 259 | if accelerator.is_main_process: |
| 253 | if epoch % sample_frequency == 0: | 260 | if epoch % sample_frequency == 0: |
| 254 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | 261 | checkpointer.save_samples(global_step + global_step_offset) |
| 255 | 262 | ||
| 256 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 263 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
| 257 | checkpointer.checkpoint(global_step + global_step_offset, "training") | 264 | checkpointer.checkpoint(global_step + global_step_offset, "training") |
| @@ -353,7 +360,7 @@ def train_loop( | |||
| 353 | if accelerator.is_main_process: | 360 | if accelerator.is_main_process: |
| 354 | print("Finished!") | 361 | print("Finished!") |
| 355 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 362 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
| 356 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | 363 | checkpointer.save_samples(global_step + global_step_offset) |
| 357 | accelerator.end_training() | 364 | accelerator.end_training() |
| 358 | 365 | ||
| 359 | except KeyboardInterrupt: | 366 | except KeyboardInterrupt: |
diff --git a/training/util.py b/training/util.py index cc4cdee..1008021 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -44,32 +44,29 @@ class CheckpointerBase: | |||
| 44 | train_dataloader, | 44 | train_dataloader, |
| 45 | val_dataloader, | 45 | val_dataloader, |
| 46 | output_dir: Path, | 46 | output_dir: Path, |
| 47 | sample_image_size: int, | 47 | sample_steps: int = 20, |
| 48 | sample_batches: int, | 48 | sample_guidance_scale: float = 7.5, |
| 49 | sample_batch_size: int, | 49 | sample_image_size: int = 768, |
| 50 | sample_batches: int = 1, | ||
| 51 | sample_batch_size: int = 1, | ||
| 50 | seed: Optional[int] = None | 52 | seed: Optional[int] = None |
| 51 | ): | 53 | ): |
| 52 | self.train_dataloader = train_dataloader | 54 | self.train_dataloader = train_dataloader |
| 53 | self.val_dataloader = val_dataloader | 55 | self.val_dataloader = val_dataloader |
| 54 | self.output_dir = output_dir | 56 | self.output_dir = output_dir |
| 55 | self.sample_image_size = sample_image_size | 57 | self.sample_image_size = sample_image_size |
| 56 | self.seed = seed if seed is not None else torch.random.seed() | 58 | self.sample_steps = sample_steps |
| 59 | self.sample_guidance_scale = sample_guidance_scale | ||
| 57 | self.sample_batches = sample_batches | 60 | self.sample_batches = sample_batches |
| 58 | self.sample_batch_size = sample_batch_size | 61 | self.sample_batch_size = sample_batch_size |
| 62 | self.seed = seed if seed is not None else torch.random.seed() | ||
| 59 | 63 | ||
| 60 | @torch.no_grad() | 64 | @torch.no_grad() |
| 61 | def checkpoint(self, step: int, postfix: str): | 65 | def checkpoint(self, step: int, postfix: str): |
| 62 | pass | 66 | pass |
| 63 | 67 | ||
| 64 | @torch.inference_mode() | 68 | @torch.inference_mode() |
| 65 | def save_samples( | 69 | def save_samples(self, pipeline, step: int): |
| 66 | self, | ||
| 67 | pipeline, | ||
| 68 | step: int, | ||
| 69 | num_inference_steps: int, | ||
| 70 | guidance_scale: float = 7.5, | ||
| 71 | eta: float = 0.0 | ||
| 72 | ): | ||
| 73 | samples_path = Path(self.output_dir).joinpath("samples") | 70 | samples_path = Path(self.output_dir).joinpath("samples") |
| 74 | 71 | ||
| 75 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 72 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
| @@ -110,9 +107,8 @@ class CheckpointerBase: | |||
| 110 | height=self.sample_image_size, | 107 | height=self.sample_image_size, |
| 111 | width=self.sample_image_size, | 108 | width=self.sample_image_size, |
| 112 | generator=gen, | 109 | generator=gen, |
| 113 | guidance_scale=guidance_scale, | 110 | guidance_scale=self.sample_guidance_scale, |
| 114 | eta=eta, | 111 | num_inference_steps=self.sample_steps, |
| 115 | num_inference_steps=num_inference_steps, | ||
| 116 | output_type='pil' | 112 | output_type='pil' |
| 117 | ).images | 113 | ).images |
| 118 | 114 | ||
