From e2d3a62bce63fcde940395a1c5618c4eb43385a9 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 09:25:13 +0100 Subject: Cleanup --- training/common.py | 97 +++++++++++++++++++++++++++++------------------------- 1 file changed, 52 insertions(+), 45 deletions(-) (limited to 'training/common.py') diff --git a/training/common.py b/training/common.py index b6964a3..f5ab326 100644 --- a/training/common.py +++ b/training/common.py @@ -45,42 +45,44 @@ def generate_class_images( ): missing_data = [item for item in data_train if not item.class_image_path.exists()] - if len(missing_data) != 0: - batched_data = [ - missing_data[i:i+sample_batch_size] - for i in range(0, len(missing_data), sample_batch_size) - ] - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=scheduler, - ).to(accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) - - with torch.inference_mode(): - for batch in batched_data: - image_name = [item.class_image_path for item in batch] - prompt = [item.cprompt for item in batch] - nprompt = [item.nprompt for item in batch] - - images = pipeline( - prompt=prompt, - negative_prompt=nprompt, - height=sample_image_size, - width=sample_image_size, - num_inference_steps=sample_steps - ).images - - for i, image in enumerate(images): - image.save(image_name[i]) - - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if len(missing_data) == 0: + return + + batched_data = [ + missing_data[i:i+sample_batch_size] + for i in range(0, len(missing_data), sample_batch_size) + ] + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + ).to(accelerator.device) + pipeline.set_progress_bar_config(dynamic_ncols=True) + + with torch.inference_mode(): + for batch in batched_data: + image_name = [item.class_image_path for item in batch] + prompt = [item.cprompt for item in batch] + nprompt = [item.nprompt for item in batch] + + images = pipeline( + prompt=prompt, + negative_prompt=nprompt, + height=sample_image_size, + width=sample_image_size, + num_inference_steps=sample_steps + ).images + + for i, image in enumerate(images): + image.save(image_name[i]) + + del pipeline + + if torch.cuda.is_available(): + torch.cuda.empty_cache() def get_models(pretrained_model_name_or_path: str): @@ -119,7 +121,7 @@ def add_placeholder_tokens( for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): embeddings.add_embed(placeholder_token_id, initializer_token_id) - return placeholder_token_ids + return placeholder_token_ids, initializer_token_ids def loss_step( @@ -127,7 +129,6 @@ def loss_step( noise_scheduler: DDPMScheduler, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, - with_prior: bool, prior_loss_weight: float, seed: int, step: int, @@ -138,16 +139,23 @@ def loss_step( latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 + generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None + # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) + noise = torch.randn( + latents.shape, + dtype=latents.dtype, + layout=latents.layout, + device=latents.device, + generator=generator + ) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), - generator=timesteps_gen, + generator=generator, device=latents.device, ) timesteps = timesteps.long() @@ -176,7 +184,7 @@ def loss_step( else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - if with_prior: + if batch["with_prior"]: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) @@ -207,7 +215,6 @@ def train_loop( val_dataloader: DataLoader, loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], sample_frequency: int = 10, - sample_steps: int = 20, checkpoint_frequency: int = 50, global_step_offset: int = 0, num_epochs: int = 100, @@ -251,7 +258,7 @@ def train_loop( for epoch in range(num_epochs): if accelerator.is_main_process: if epoch % sample_frequency == 0: - checkpointer.save_samples(global_step + global_step_offset, sample_steps) + checkpointer.save_samples(global_step + global_step_offset) if epoch % checkpoint_frequency == 0 and epoch != 0: checkpointer.checkpoint(global_step + global_step_offset, "training") @@ -353,7 +360,7 @@ def train_loop( if accelerator.is_main_process: print("Finished!") checkpointer.checkpoint(global_step + global_step_offset, "end") - checkpointer.save_samples(global_step + global_step_offset, sample_steps) + checkpointer.save_samples(global_step + global_step_offset) accelerator.end_training() except KeyboardInterrupt: -- cgit v1.2.3-54-g00ecf