Diffstat (limited to 'training')
 training/common.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+), 0 deletions(-)
diff --git a/training/common.py b/training/common.py
index ab2741a..67c2ab6 100644
--- a/training/common.py
+++ b/training/common.py
@@ -3,6 +3,60 @@ import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+
+
+def generate_class_images(
+    accelerator,
+    text_encoder,
+    vae,
+    unet,
+    tokenizer,
+    scheduler,
+    data_train,
+    sample_batch_size,
+    sample_image_size,
+    sample_steps
+):
+    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+
+    if len(missing_data) != 0:
+        batched_data = [
+            missing_data[i:i+sample_batch_size]
+            for i in range(0, len(missing_data), sample_batch_size)
+        ]
+
+        pipeline = VlpnStableDiffusion(
+            text_encoder=text_encoder,
+            vae=vae,
+            unet=unet,
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+        ).to(accelerator.device)
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+        with torch.inference_mode():
+            for batch in batched_data:
+                image_name = [item.class_image_path for item in batch]
+                prompt = [item.cprompt for item in batch]
+                nprompt = [item.nprompt for item in batch]
+
+                images = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=sample_image_size,
+                    width=sample_image_size,
+                    num_inference_steps=sample_steps
+                ).images
+
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
+
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
 
 def run_model(
     vae: AutoencoderKL,
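
Below is a minimal usage sketch, not part of this commit, showing how generate_class_images might be wired into a training script. The base model ID, the Item dataclass, and the sample_* values are illustrative assumptions; only the function signature and the item attributes it reads (class_image_path, cprompt, nprompt) come from the diff above.

# Hypothetical usage sketch -- not part of this commit. The model ID, the Item
# dataclass, and the sample_* values are assumptions for illustration; only the
# generate_class_images() signature and the item attributes it reads are taken
# from the diff.
from dataclasses import dataclass
from pathlib import Path

from accelerate import Accelerator
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from training.common import generate_class_images


@dataclass
class Item:
    class_image_path: Path  # where the generated class image is saved
    cprompt: str            # class (prior-preservation) prompt
    nprompt: str            # negative prompt


def main():
    accelerator = Accelerator()
    model = "runwayml/stable-diffusion-v1-5"  # assumed base model

    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet")
    scheduler = DDPMScheduler.from_pretrained(model, subfolder="scheduler")

    # Dummy training items; generate_class_images only fills in the entries
    # whose class_image_path does not yet exist on disk.
    data_train = [
        Item(Path(f"class_images/{i}.png"), "a photo of a dog", "lowres, blurry")
        for i in range(8)
    ]

    generate_class_images(
        accelerator,
        text_encoder,
        vae,
        unet,
        tokenizer,
        scheduler,
        data_train,
        sample_batch_size=4,
        sample_image_size=512,
        sample_steps=20,
    )


if __name__ == "__main__":
    main()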