From e2d3a62bce63fcde940395a1c5618c4eb43385a9 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 09:25:13 +0100
Subject: Cleanup

---
 training/common.py | 97 +++++++++++++++++++++++++++++-------------------------
 training/util.py   | 26 +++++++--------
 2 files changed, 63 insertions(+), 60 deletions(-)

(limited to 'training')

diff --git a/training/common.py b/training/common.py
index b6964a3..f5ab326 100644
--- a/training/common.py
+++ b/training/common.py
@@ -45,42 +45,44 @@ def generate_class_images(
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
-    if len(missing_data) != 0:
-        batched_data = [
-            missing_data[i:i+sample_batch_size]
-            for i in range(0, len(missing_data), sample_batch_size)
-        ]
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=scheduler,
-        ).to(accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        with torch.inference_mode():
-            for batch in batched_data:
-                image_name = [item.class_image_path for item in batch]
-                prompt = [item.cprompt for item in batch]
-                nprompt = [item.nprompt for item in batch]
-
-                images = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=sample_image_size,
-                    width=sample_image_size,
-                    num_inference_steps=sample_steps
-                ).images
-
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
-
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
 
 def get_models(pretrained_model_name_or_path: str):
@@ -119,7 +121,7 @@ def add_placeholder_tokens(
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
         embeddings.add_embed(placeholder_token_id, initializer_token_id)
 
-    return placeholder_token_ids
+    return placeholder_token_ids, initializer_token_ids
 
 
 def loss_step(
@@ -127,7 +129,6 @@
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -138,16 +139,23 @@
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
     latents = latents * 0.18215
+    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+
     # Sample noise that we'll add to the latents
-    noise = torch.randn_like(latents)
+    noise = torch.randn(
+        latents.shape,
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+        generator=generator
+    )
     bsz = latents.shape[0]
 
     # Sample a random timestep for each image
-    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
     timesteps = torch.randint(
         0,
         noise_scheduler.config.num_train_timesteps,
         (bsz,),
-        generator=timesteps_gen,
+        generator=generator,
         device=latents.device,
     )
     timesteps = timesteps.long()
@@ -176,7 +184,7 @@
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if with_prior:
+    if batch["with_prior"]:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -207,7 +215,6 @@
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
-    sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
@@ -251,7 +258,7 @@
     for epoch in range(num_epochs):
         if accelerator.is_main_process:
             if epoch % sample_frequency == 0:
-                checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+                checkpointer.save_samples(global_step + global_step_offset)
 
             if epoch % checkpoint_frequency == 0 and epoch != 0:
                 checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -353,7 +360,7 @@
         if accelerator.is_main_process:
             print("Finished!")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+            checkpointer.save_samples(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
diff --git a/training/util.py b/training/util.py
index cc4cdee..1008021 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,32 +44,29 @@ class CheckpointerBase:
         train_dataloader,
         val_dataloader,
         output_dir: Path,
-        sample_image_size: int,
-        sample_batches: int,
-        sample_batch_size: int,
+        sample_steps: int = 20,
+        sample_guidance_scale: float = 7.5,
+        sample_image_size: int = 768,
+        sample_batches: int = 1,
+        sample_batch_size: int = 1,
         seed: Optional[int] = None
     ):
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
-        self.seed = seed if seed is not None else torch.random.seed()
+        self.sample_steps = sample_steps
+        self.sample_guidance_scale = sample_guidance_scale
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
+        self.seed = seed if seed is not None else torch.random.seed()
 
     @torch.no_grad()
     def checkpoint(self, step: int, postfix: str):
         pass
 
     @torch.inference_mode()
-    def save_samples(
-        self,
-        pipeline,
-        step: int,
-        num_inference_steps: int,
-        guidance_scale: float = 7.5,
-        eta: float = 0.0
-    ):
+    def save_samples(self, pipeline, step: int):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
@@ -110,9 +107,8 @@ class CheckpointerBase:
                     height=self.sample_image_size,
                     width=self.sample_image_size,
                     generator=gen,
-                    guidance_scale=guidance_scale,
-                    eta=eta,
-                    num_inference_steps=num_inference_steps,
+                    guidance_scale=self.sample_guidance_scale,
+                    num_inference_steps=self.sample_steps,
                     output_type='pil'
                 ).images
 
--
cgit v1.2.3-54-g00ecf
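
Side note, not part of the commit: the core change in loss_step is that a single seeded torch.Generator now drives both the noise and the timestep sampling during evaluation, so validation loss is computed on identical draws every time a given step is revisited, while passing generator=None during training keeps sampling fully random. Below is a minimal, self-contained sketch of that pattern; the helper name sample_noise_and_timesteps and the toy tensor shapes are hypothetical, but the torch calls mirror the patched code.

import torch

def sample_noise_and_timesteps(latents, num_train_timesteps, seed, step, eval=False):
    # One generator for both draws; seeded only in eval so that a given
    # (seed, step) pair always produces the same noise and timesteps.
    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
    timesteps = torch.randint(
        0,
        num_train_timesteps,
        (latents.shape[0],),
        generator=generator,
        device=latents.device,
    ).long()
    return noise, timesteps

# Two eval calls at the same step are bit-identical; training calls are not.
latents = torch.zeros(2, 4, 8, 8)
n1, t1 = sample_noise_and_timesteps(latents, 1000, seed=42, step=3, eval=True)
n2, t2 = sample_noise_and_timesteps(latents, 1000, seed=42, step=3, eval=True)
assert torch.equal(n1, n2) and torch.equal(t1, t2)

Seeding with seed + step rather than a constant keeps successive validation batches decorrelated from one another while still making each individual step deterministic.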