From 3353ffb64c280a938a0f2513d13b716c1fca8c02 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 7 Jan 2023 17:10:06 +0100
Subject: Cleanup

---
 training/common.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

(limited to 'training')

diff --git a/training/common.py b/training/common.py
index ab2741a..67c2ab6 100644
--- a/training/common.py
+++ b/training/common.py
@@ -3,6 +3,60 @@ import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+
+
+def generate_class_images(
+    accelerator,
+    text_encoder,
+    vae,
+    unet,
+    tokenizer,
+    scheduler,
+    data_train,
+    sample_batch_size,
+    sample_image_size,
+    sample_steps
+):
+    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+
+    if len(missing_data) != 0:
+        batched_data = [
+            missing_data[i:i+sample_batch_size]
+            for i in range(0, len(missing_data), sample_batch_size)
+        ]
+
+        pipeline = VlpnStableDiffusion(
+            text_encoder=text_encoder,
+            vae=vae,
+            unet=unet,
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+        ).to(accelerator.device)
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+        with torch.inference_mode():
+            for batch in batched_data:
+                image_name = [item.class_image_path for item in batch]
+                prompt = [item.cprompt for item in batch]
+                nprompt = [item.nprompt for item in batch]
+
+                images = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=sample_image_size,
+                    width=sample_image_size,
+                    num_inference_steps=sample_steps
+                ).images
+
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
+
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
 
 def run_model(
     vae: AutoencoderKL,
-- 
cgit v1.2.3-54-g00ecf
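
For context, a minimal sketch (not part of the patch) of how the new generate_class_images helper might be called from a training script. The TrainItem dataclass and every object passed in (accelerator, text_encoder, vae, unet, tokenizer, scheduler) are assumptions inferred from the parameters and attributes the function uses; the actual dataset and setup code in this repository may differ.

    from dataclasses import dataclass
    from pathlib import Path

    # Hypothetical item shape: the helper only reads class_image_path,
    # cprompt and nprompt from each dataset entry.
    @dataclass
    class TrainItem:
        class_image_path: Path  # where the generated class image is saved
        cprompt: str            # prompt used to generate the class image
        nprompt: str            # negative prompt

    # Assumed to exist already: an accelerate.Accelerator plus the loaded
    # text_encoder, vae, unet, tokenizer and noise scheduler.
    generate_class_images(
        accelerator=accelerator,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
        data_train=[TrainItem(Path("class_images/0.png"), "a photo of a dog", "low quality")],
        sample_batch_size=4,
        sample_image_size=512,
        sample_steps=20,
    )

The helper only generates images whose class_image_path does not yet exist, so calling it at the start of every run is cheap once the class images have been created.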