summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/common.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/training/common.py b/training/common.py
index ab2741a..67c2ab6 100644
--- a/training/common.py
+++ b/training/common.py
@@ -3,6 +3,60 @@ import torch.nn.functional as F
3 3
4from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 4from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
5 5
6from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
7
8
9def generate_class_images(
10 accelerator,
11 text_encoder,
12 vae,
13 unet,
14 tokenizer,
15 scheduler,
16 data_train,
17 sample_batch_size,
18 sample_image_size,
19 sample_steps
20):
21 missing_data = [item for item in data_train if not item.class_image_path.exists()]
22
23 if len(missing_data) != 0:
24 batched_data = [
25 missing_data[i:i+sample_batch_size]
26 for i in range(0, len(missing_data), sample_batch_size)
27 ]
28
29 pipeline = VlpnStableDiffusion(
30 text_encoder=text_encoder,
31 vae=vae,
32 unet=unet,
33 tokenizer=tokenizer,
34 scheduler=scheduler,
35 ).to(accelerator.device)
36 pipeline.set_progress_bar_config(dynamic_ncols=True)
37
38 with torch.inference_mode():
39 for batch in batched_data:
40 image_name = [item.class_image_path for item in batch]
41 prompt = [item.cprompt for item in batch]
42 nprompt = [item.nprompt for item in batch]
43
44 images = pipeline(
45 prompt=prompt,
46 negative_prompt=nprompt,
47 height=sample_image_size,
48 width=sample_image_size,
49 num_inference_steps=sample_steps
50 ).images
51
52 for i, image in enumerate(images):
53 image.save(image_name[i])
54
55 del pipeline
56
57 if torch.cuda.is_available():
58 torch.cuda.empty_cache()
59
6 60
7def run_model( 61def run_model(
8 vae: AutoencoderKL, 62 vae: AutoencoderKL,