 training/functional.py | 25 ++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 2dcfbb8..2f7f837 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -9,13 +9,14 @@ import itertools
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
+from torchvision.utils import make_grid
+import numpy as np
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
 
 from tqdm.auto import tqdm
-from PIL import Image
 
 from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
@@ -64,14 +65,6 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
 def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -107,7 +100,6 @@ def save_samples(
     print(f"Saving samples for step {step}...")
 
     grid_cols = min(batch_size, 4)
-    grid_rows = (num_batches * batch_size) // grid_cols
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
@@ -159,12 +151,19 @@ def save_samples(
                 guidance_scale=guidance_scale,
                 sag_scale=0,
                 num_inference_steps=num_steps,
-                output_type='pil'
+                output_type=None,
             ).images
 
-            all_samples += samples
+            all_samples.append(torch.from_numpy(samples))
+
+        all_samples = torch.cat(all_samples)
+
+        for tracker in accelerator.trackers:
+            if tracker.name == "tensorboard":
+                tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
 
-        image_grid = make_grid(all_samples, grid_rows, grid_cols)
+        image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
+        image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
         image_grid.save(file_path, quality=85)
 
     del generator
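
For context, the commit replaces the hand-rolled PIL grid with torchvision.utils.make_grid and logs the raw sample tensors to TensorBoard. Below is a minimal standalone sketch of that path; the sample_batches name and shapes are hypothetical stand-ins for the pipeline output with output_type=None, which yields NHWC float arrays in [0, 1].

import numpy as np
import torch
from PIL import Image
from torchvision.utils import make_grid

# Hypothetical stand-in for pipeline(..., output_type=None).images:
# two batches of 4 RGB images as NHWC float32 in [0, 1].
sample_batches = [np.random.rand(4, 64, 64, 3).astype(np.float32) for _ in range(2)]

# Collect each batch as a tensor, then concatenate along the batch axis.
all_samples = torch.cat([torch.from_numpy(batch) for batch in sample_batches])

# A TensorBoard SummaryWriter accepts the NHWC tensor directly, e.g.:
#     writer.add_images("samples", all_samples, step, dataformats="NHWC")

# torchvision's make_grid expects NCHW; the second argument (nrow) is the
# number of images per row, matching grid_cols in the diff.
grid = make_grid(all_samples.permute(0, 3, 1, 2), 4)

# Back to HWC uint8 for PIL, mirroring what pipeline.numpy_to_pil does.
grid_np = (grid.permute(1, 2, 0).numpy() * 255).round().astype("uint8")
Image.fromarray(grid_np).save("grid.jpg", quality=85)

One behavioral difference worth noting: make_grid inserts padding between tiles by default (padding=2), so the saved grid comes out slightly larger than the edge-to-edge mosaic the removed PIL helper produced.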