diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/common.py | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/training/common.py b/training/common.py index ab2741a..67c2ab6 100644 --- a/training/common.py +++ b/training/common.py | |||
| @@ -3,6 +3,60 @@ import torch.nn.functional as F | |||
| 3 | 3 | ||
| 4 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 4 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
| 5 | 5 | ||
| 6 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 7 | |||
| 8 | |||
| 9 | def generate_class_images( | ||
| 10 | accelerator, | ||
| 11 | text_encoder, | ||
| 12 | vae, | ||
| 13 | unet, | ||
| 14 | tokenizer, | ||
| 15 | scheduler, | ||
| 16 | data_train, | ||
| 17 | sample_batch_size, | ||
| 18 | sample_image_size, | ||
| 19 | sample_steps | ||
| 20 | ): | ||
| 21 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | ||
| 22 | |||
| 23 | if len(missing_data) != 0: | ||
| 24 | batched_data = [ | ||
| 25 | missing_data[i:i+sample_batch_size] | ||
| 26 | for i in range(0, len(missing_data), sample_batch_size) | ||
| 27 | ] | ||
| 28 | |||
| 29 | pipeline = VlpnStableDiffusion( | ||
| 30 | text_encoder=text_encoder, | ||
| 31 | vae=vae, | ||
| 32 | unet=unet, | ||
| 33 | tokenizer=tokenizer, | ||
| 34 | scheduler=scheduler, | ||
| 35 | ).to(accelerator.device) | ||
| 36 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
| 37 | |||
| 38 | with torch.inference_mode(): | ||
| 39 | for batch in batched_data: | ||
| 40 | image_name = [item.class_image_path for item in batch] | ||
| 41 | prompt = [item.cprompt for item in batch] | ||
| 42 | nprompt = [item.nprompt for item in batch] | ||
| 43 | |||
| 44 | images = pipeline( | ||
| 45 | prompt=prompt, | ||
| 46 | negative_prompt=nprompt, | ||
| 47 | height=sample_image_size, | ||
| 48 | width=sample_image_size, | ||
| 49 | num_inference_steps=sample_steps | ||
| 50 | ).images | ||
| 51 | |||
| 52 | for i, image in enumerate(images): | ||
| 53 | image.save(image_name[i]) | ||
| 54 | |||
| 55 | del pipeline | ||
| 56 | |||
| 57 | if torch.cuda.is_available(): | ||
| 58 | torch.cuda.empty_cache() | ||
| 59 | |||
| 6 | 60 | ||
| 7 | def run_model( | 61 | def run_model( |
| 8 | vae: AutoencoderKL, | 62 | vae: AutoencoderKL, |
