From a7154d94c46bf005682313a11a09e3c74d533f1c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 11 Apr 2023 18:02:20 +0200 Subject: Store sample images in Tensorboard as well --- training/functional.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/training/functional.py b/training/functional.py index 2dcfbb8..2f7f837 100644 --- a/training/functional.py +++ b/training/functional.py @@ -9,13 +9,14 @@ import itertools import torch import torch.nn.functional as F from torch.utils.data import DataLoader +from torchvision.utils import make_grid +import numpy as np from accelerate import Accelerator from transformers import CLIPTextModel from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin from tqdm.auto import tqdm -from PIL import Image from data.csv import VlpnDataset from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion @@ -64,14 +65,6 @@ class TrainingStrategy(): prepare: TrainingStrategyPrepareCallable -def make_grid(images, rows, cols): - w, h = images[0].size - grid = Image.new('RGB', size=(cols*w, rows*h)) - for i, image in enumerate(images): - grid.paste(image, box=(i % cols*w, i//cols*h)) - return grid - - def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') @@ -107,7 +100,6 @@ def save_samples( print(f"Saving samples for step {step}...") grid_cols = min(batch_size, 4) - grid_rows = (num_batches * batch_size) // grid_cols pipeline = VlpnStableDiffusion( text_encoder=text_encoder, @@ -159,12 +151,19 @@ def save_samples( guidance_scale=guidance_scale, sag_scale=0, num_inference_steps=num_steps, - output_type='pil' + output_type=None, ).images - all_samples += samples + all_samples.append(torch.from_numpy(samples)) + + all_samples = torch.cat(all_samples) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") - image_grid = make_grid(all_samples, grid_rows, grid_cols) + image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) + image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] image_grid.save(file_path, quality=85) del generator -- cgit v1.2.3-54-g00ecf