| author | Volpeon <git@volpeon.ink> | 2023-04-11 18:02:20 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-11 18:02:20 +0200 |
| commit | a7154d94c46bf005682313a11a09e3c74d533f1c (patch) | |
| tree | 603d81c0b86b1fd3415ee9e3d9cb6b99fedd924b | |
| parent | Update (diff) | |
Store sample images in Tensorboard as well
| -rw-r--r-- | training/functional.py | 25 |
1 file changed, 12 insertions(+), 13 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 2dcfbb8..2f7f837 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -9,13 +9,14 @@ import itertools
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
+from torchvision.utils import make_grid
+import numpy as np
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
 
 from tqdm.auto import tqdm
-from PIL import Image
 
 from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
@@ -64,14 +65,6 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
 def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -107,7 +100,6 @@ def save_samples(
     print(f"Saving samples for step {step}...")
 
     grid_cols = min(batch_size, 4)
-    grid_rows = (num_batches * batch_size) // grid_cols
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
@@ -159,12 +151,19 @@ def save_samples(
                 guidance_scale=guidance_scale,
                 sag_scale=0,
                 num_inference_steps=num_steps,
-                output_type='pil'
+                output_type=None,
             ).images
 
-            all_samples += samples
+            all_samples.append(torch.from_numpy(samples))
+
+        all_samples = torch.cat(all_samples)
+
+        for tracker in accelerator.trackers:
+            if tracker.name == "tensorboard":
+                tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
 
-        image_grid = make_grid(all_samples, grid_rows, grid_cols)
+        image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
+        image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
         image_grid.save(file_path, quality=85)
 
     del generator
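For reference, a minimal standalone sketch of the logging path this patch introduces. A dummy NHWC float batch stands in for the pipeline output (with `output_type=None`, diffusers returns NumPy arrays in `[0, 1]`); the tag name, log directory, output path, and batch shape below are placeholders:

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

# Placeholder for the pipeline output: an NHWC float batch in [0, 1],
# as returned when output_type=None.
samples = np.random.rand(8, 64, 64, 3).astype(np.float32)

all_samples = torch.from_numpy(samples)

# Log the raw batch to TensorBoard; add_images accepts NHWC tensors
# directly when told so via dataformats (log dir is a placeholder).
writer = SummaryWriter("runs/sample-demo")
writer.add_images("samples", all_samples, 0, dataformats="NHWC")
writer.close()

# Build the on-disk grid the same way the patch does: torchvision's
# make_grid expects NCHW, and its second argument (nrow) sets the
# number of images per grid row.
grid_cols = min(all_samples.shape[0], 4)
image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)  # (C, H', W')

# Equivalent of pipeline.numpy_to_pil: back to HWC uint8, then PIL.
arr = (image_grid.permute(1, 2, 0).numpy() * 255).round().astype("uint8")
Image.fromarray(arr).save("grid.jpg", quality=85)
```

The permute calls are the crux of the change: the pipeline emits NHWC arrays, TensorBoard's `add_images` can consume NHWC directly via `dataformats`, but torchvision's `make_grid` (which replaces the hand-rolled PIL helper the patch deletes) only accepts NCHW, so the batch is flipped to NCHW for the grid and the result is flipped back for saving.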
