 training/functional.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 2dcfbb8..2f7f837 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -9,13 +9,14 @@ import itertools
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
+from torchvision.utils import make_grid
+import numpy as np
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
 
 from tqdm.auto import tqdm
-from PIL import Image
 
 from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
@@ -64,14 +65,6 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
 def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -107,7 +100,6 @@ def save_samples(
     print(f"Saving samples for step {step}...")
 
     grid_cols = min(batch_size, 4)
-    grid_rows = (num_batches * batch_size) // grid_cols
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
@@ -159,12 +151,19 @@ def save_samples(
                 guidance_scale=guidance_scale,
                 sag_scale=0,
                 num_inference_steps=num_steps,
-                output_type='pil'
+                output_type=None,
             ).images
 
-            all_samples += samples
+            all_samples.append(torch.from_numpy(samples))
+
+        all_samples = torch.cat(all_samples)
+
+        for tracker in accelerator.trackers:
+            if tracker.name == "tensorboard":
+                tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
 
-        image_grid = make_grid(all_samples, grid_rows, grid_cols)
+        image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
+        image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
         image_grid.save(file_path, quality=85)
 
     del generator
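
Note on the new grid path: a minimal sketch of what the rewritten sample-saving code does, assuming the pipeline returns float32 numpy images in NHWC layout with values in [0, 1] when called with output_type=None. The random batch stands in for real samples, and Image.fromarray stands in for pipeline.numpy_to_pil.

    import numpy as np
    import torch
    from torchvision.utils import make_grid
    from PIL import Image

    # Stand-in for pipeline(...).images with output_type=None:
    # (N, H, W, C) float32 values in [0, 1].
    samples = torch.from_numpy(np.random.rand(8, 64, 64, 3).astype(np.float32))

    grid_cols = 4
    # torchvision's make_grid expects NCHW input; the second argument (nrow)
    # is the number of images per row, so grid_cols plays the old grid_rows/cols role.
    grid = make_grid(samples.permute(0, 3, 1, 2), grid_cols)

    # Back to HWC numpy, then to a PIL image (roughly what numpy_to_pil does internally).
    grid = grid.permute(1, 2, 0).numpy()
    image = Image.fromarray((grid * 255).round().astype("uint8"))
    image.save("grid.jpg", quality=85)

The same NHWC tensor batch is what gets handed to the TensorBoard tracker via add_images(..., dataformats="NHWC"), so the grid construction and the logging share one representation.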