From 0f493e1ac8406de061861ed390f283e821180e79 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 11:26:31 +0200 Subject: Use euler_a for samples in learning scripts; backported improvement from Dreambooth to Textual Inversion --- data/dreambooth/csv.py | 1 - data/textual_inversion/csv.py | 98 +++++++------- dreambooth.py | 26 +++- textual_inversion.py | 308 ++++++++++++++++++++++-------------------- 4 files changed, 230 insertions(+), 203 deletions(-) diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 99bcf12..1676d35 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -2,7 +2,6 @@ import math import os import pandas as pd from pathlib import Path -import PIL import pytorch_lightning as pl from PIL import Image from torch.utils.data import Dataset, DataLoader, random_split diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 0d1e96e..f306c7a 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py @@ -1,11 +1,10 @@ import os import numpy as np import pandas as pd -import random -import PIL +from pathlib import Path +import math import pytorch_lightning as pl from PIL import Image -import torch from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms @@ -13,29 +12,32 @@ from torchvision import transforms class CSVDataModule(pl.LightningDataModule): def __init__(self, batch_size, - data_root, + data_file, tokenizer, size=512, repeats=100, interpolation="bicubic", placeholder_token="*", - flip_p=0.5, center_crop=False): super().__init__() - self.data_root = data_root + self.data_file = Path(data_file) + + if not self.data_file.is_file(): + raise ValueError("data_file must be a file") + + self.data_root = self.data_file.parent self.tokenizer = tokenizer self.size = size self.repeats = repeats self.placeholder_token = placeholder_token self.center_crop = center_crop - self.flip_p = flip_p self.interpolation = interpolation self.batch_size = batch_size def prepare_data(self): - metadata = pd.read_csv(f'{self.data_root}/list.csv') + metadata = pd.read_csv(self.data_file) image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] captions = [caption for caption in metadata['caption'].values] skips = [skip for skip in metadata['skip'].values] @@ -47,9 +49,9 @@ class CSVDataModule(pl.LightningDataModule): self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, - flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) + placeholder_token=self.placeholder_token, center_crop=self.center_crop) val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, - flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) + placeholder_token=self.placeholder_token, center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) @@ -67,48 +69,54 @@ class CSVDataset(Dataset): size=512, repeats=1, interpolation="bicubic", - flip_p=0.5, placeholder_token="*", center_crop=False, + batch_size=1, ): self.data = data self.tokenizer = tokenizer - - self.num_images = len(self.data) - self._length = self.num_images * repeats - 
self.placeholder_token = placeholder_token + self.batch_size = batch_size + self.cache = {} - self.size = size - self.center_crop = center_crop - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] - self.flip = transforms.RandomHorizontalFlip(p=flip_p) + self.num_instance_images = len(self.data) + self._length = self.num_instance_images * repeats - self.cache = {} + self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, + "bilinear": transforms.InterpolationMode.BILINEAR, + "bicubic": transforms.InterpolationMode.BICUBIC, + "lanczos": transforms.InterpolationMode.LANCZOS, + }[interpolation] + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=self.interpolation), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): - return self._length + return math.ceil(self._length / self.batch_size) * self.batch_size - def get_example(self, i, flipped): - image_path, text = self.data[i % self.num_images] + def get_example(self, i): + image_path, text = self.data[i % self.num_instance_images] if image_path in self.cache: return self.cache[image_path] example = {} - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") + instance_image = Image.open(image_path) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") text = text.format(self.placeholder_token) - example["prompt"] = text + example["prompts"] = text + example["pixel_values"] = instance_image example["input_ids"] = self.tokenizer( text, padding="max_length", @@ -117,29 +125,15 @@ class CSVDataset(Dataset): return_tensors="pt", ).input_ids[0] - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - - if self.center_crop: - crop = min(img.shape[0], img.shape[1]) - h, w, = img.shape[0], img.shape[1] - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] - - image = Image.fromarray(img) - image = image.resize((self.size, self.size), - resample=self.interpolation) - image = self.flip(image) - image = np.array(image).astype(np.uint8) - image = (image / 127.5 - 1.0).astype(np.float32) - - example["key"] = "-".join([image_path, "-", str(flipped)]) - example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) - self.cache[image_path] = example return example def __getitem__(self, i): - flipped = random.choice([False, True]) - example = self.get_example(i, flipped) + example = {} + unprocessed_example = self.get_example(i) + + example["prompts"] = unprocessed_example["prompts"] + example["input_ids"] = unprocessed_example["input_ids"] + example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) + return example diff --git a/dreambooth.py b/dreambooth.py index 4d7366c..744d1bc 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -14,12 +14,14 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel +from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler from pipelines.stable_diffusion.no_check import NoCheck from PIL import Image from 
tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from slugify import slugify +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion import json from data.dreambooth.csv import CSVDataModule @@ -215,7 +217,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=80, + default=30, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -377,15 +379,16 @@ class Checkpointer: samples_path = Path(self.output_dir).joinpath("samples") unwrapped = self.accelerator.unwrap_model(self.unet) - pipeline = StableDiffusionPipeline( + scheduler = EulerAScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + + pipeline = VlpnStableDiffusion( text_encoder=self.text_encoder, vae=self.vae, unet=unwrapped, tokenizer=self.tokenizer, - scheduler=LMSDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ), - safety_checker=NoCheck(), + scheduler=scheduler, feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), ).to(self.accelerator.device) pipeline.enable_attention_slicing() @@ -411,6 +414,8 @@ class Checkpointer: prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] + generator = torch.Generator(device="cuda").manual_seed(self.seed + i) + with self.accelerator.autocast(): samples = pipeline( prompt=prompt, @@ -420,10 +425,13 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, + generator=generator, output_type='pil' )["sample"] all_samples += samples + + del generator del samples image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) @@ -444,6 +452,8 @@ class Checkpointer: prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] + generator = torch.Generator(device="cuda").manual_seed(self.seed + i) + with self.accelerator.autocast(): samples = pipeline( prompt=prompt, @@ -452,10 +462,13 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, + generator=generator, output_type='pil' )["sample"] all_samples += samples + + del generator del samples image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) @@ -465,6 +478,7 @@ class Checkpointer: del image_grid del unwrapped + del scheduler del pipeline if torch.cuda.is_available(): diff --git a/textual_inversion.py b/textual_inversion.py index 399d876..7a7d7fc 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -3,6 +3,8 @@ import itertools import math import os import datetime +import logging +from pathlib import Path import numpy as np import torch @@ -13,12 +15,13 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel +from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler -from pipelines.stable_diffusion.no_check import NoCheck from PIL import Image from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from slugify import slugify +from 
pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion import json import os @@ -44,10 +47,10 @@ def parse_args(): help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( - "--train_data_dir", + "--train_data_file", type=str, default=None, - help="A folder containing the training data." + help="A CSV file containing the training data." ) parser.add_argument( "--placeholder_token", @@ -145,6 +148,11 @@ def parse_args(): default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument( "--adam_beta1", type=float, @@ -225,7 +233,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=50, + default=30, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -261,8 +269,8 @@ def parse_args(): if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank - if args.train_data_dir is None: - raise ValueError("You must specify --train_data_dir") + if args.train_data_file is None: + raise ValueError("You must specify --train_data_file") if args.pretrained_model_name_or_path is None: raise ValueError("You must specify --pretrained_model_name_or_path") @@ -333,53 +341,51 @@ class Checkpointer: @torch.no_grad() def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): print("Saving checkpoint for step %d..." % step) - with self.accelerator.autocast(): - if path is None: - checkpoints_path = f"{self.output_dir}/checkpoints" - os.makedirs(checkpoints_path, exist_ok=True) - - unwrapped = self.accelerator.unwrap_model(text_encoder) - - # Save a checkpoint - learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] - learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} - - filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) - if path is not None: - torch.save(learned_embeds_dict, path) - else: - torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") - torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") - del unwrapped - del learned_embeds + + if path is None: + checkpoints_path = f"{self.output_dir}/checkpoints" + os.makedirs(checkpoints_path, exist_ok=True) + + unwrapped = self.accelerator.unwrap_model(text_encoder) + + # Save a checkpoint + learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] + learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} + + filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) + if path is not None: + torch.save(learned_embeds_dict, path) + else: + torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") + torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") + + del unwrapped + del learned_embeds @torch.no_grad() - def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): - samples_path = f"{self.output_dir}/samples/{mode}" - os.makedirs(samples_path, exist_ok=True) - checker = NoCheck() + def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): + samples_path = Path(self.output_dir).joinpath("samples") unwrapped = self.accelerator.unwrap_model(text_encoder) + scheduler = EulerAScheduler( + 
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + # Save a sample image - pipeline = StableDiffusionPipeline( + pipeline = VlpnStableDiffusion( text_encoder=unwrapped, vae=self.vae, unet=self.unet, tokenizer=self.tokenizer, - scheduler=LMSDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ), - safety_checker=NoCheck(), + scheduler=scheduler, feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), ).to(self.accelerator.device) pipeline.enable_attention_slicing() - data = { - "training": self.datamodule.train_dataloader(), - "validation": self.datamodule.val_dataloader(), - }[mode] + train_data = self.datamodule.train_dataloader() + val_data = self.datamodule.val_dataloader() - if mode == "validation" and self.stable_sample_batches > 0 and step > 0: + if self.stable_sample_batches > 0: stable_latents = torch.randn( (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), device=pipeline.device, @@ -387,14 +393,17 @@ class Checkpointer: ) all_samples = [] - filename = f"stable_step_%d.png" % (step) + file_path = samples_path.joinpath("stable", f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) - data_enum = enumerate(data) + data_enum = enumerate(val_data) # Generate and save stable samples for i in range(0, self.stable_sample_batches): prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( - batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size] + batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] + + generator = torch.Generator(device="cuda").manual_seed(self.seed + i) with self.accelerator.autocast(): samples = pipeline( @@ -405,67 +414,64 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, + generator=generator, output_type='pil' )["sample"] all_samples += samples + + del generator del samples image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) - image_grid.save(f"{samples_path}/{filename}") + image_grid.save(file_path) del all_samples del image_grid del stable_latents - all_samples = [] - filename = f"step_%d.png" % (step) + for data, pool in [(val_data, "val"), (train_data, "train")]: + all_samples = [] + file_path = samples_path.joinpath(pool, f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) - data_enum = enumerate(data) + data_enum = enumerate(data) - # Generate and save random samples - for i in range(0, self.random_sample_batches): - prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( - batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size] + for i in range(0, self.random_sample_batches): + prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( + batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] - with self.accelerator.autocast(): - samples = pipeline( - prompt=prompt, - height=self.sample_image_size, - width=self.sample_image_size, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - output_type='pil' - )["sample"] + generator = torch.Generator(device="cuda").manual_seed(self.seed + i) - all_samples += samples - del samples + with self.accelerator.autocast(): + samples = pipeline( + prompt=prompt, + height=self.sample_image_size, + width=self.sample_image_size, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + generator=generator, + 
output_type='pil' + )["sample"] - image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) - image_grid.save(f"{samples_path}/{filename}") + all_samples += samples - del all_samples - del image_grid + del generator + del samples - del checker - del unwrapped - del pipeline - torch.cuda.empty_cache() + image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) + image_grid.save(file_path) + del all_samples + del image_grid -class ImageToLatents(): - def __init__(self, vae): - self.vae = vae - self.encoded_pixel_values_cache = {} + del unwrapped + del scheduler + del pipeline - @torch.no_grad() - def __call__(self, batch): - key = "|".join(batch["key"]) - if self.encoded_pixel_values_cache.get(key, None) is None: - self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist - latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215 - return latents + if torch.cuda.is_available(): + torch.cuda.empty_cache() def main(): @@ -473,17 +479,17 @@ def main(): global_step_offset = 0 if args.resume_from is not None: - basepath = f"{args.resume_from}" + basepath = Path(args.resume_from) print("Resuming state from %s" % args.resume_from) - with open(f"{basepath}/resume.json", 'r') as f: + with open(basepath.joinpath("resume.json"), 'r') as f: state = json.load(f) global_step_offset = state["args"].get("global_step", 0) print("We've trained %d steps so far" % global_step_offset) else: now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}" - os.makedirs(basepath, exist_ok=True) + basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now) + basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, @@ -492,6 +498,8 @@ def main(): mixed_precision=args.mixed_precision ) + logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) + # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) @@ -570,8 +578,19 @@ def main(): args.train_batch_size * accelerator.num_processes ) + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + # Initialize the optimizer - optimizer = torch.optim.AdamW( + optimizer = optimizer_class( text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), @@ -585,7 +604,7 @@ def main(): ) datamodule = CSVDataModule( - data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer, + data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats, center_crop=args.center_crop) @@ -608,13 +627,12 @@ def main(): sample_batch_size=args.sample_batch_size, random_sample_batches=args.random_sample_batches, stable_sample_batches=args.stable_sample_batches, - seed=args.seed + seed=args.seed or torch.random.seed() ) # Scheduler and math around the number of training steps. 
overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -643,9 +661,10 @@ def main(): (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil( - args.max_train_steps / num_update_steps_per_epoch) + + num_val_steps_per_epoch = len(val_dataloader) + num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + val_steps = num_val_steps_per_epoch * num_epochs # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. @@ -656,7 +675,7 @@ def main(): total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") - logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Num Epochs = {num_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") @@ -666,22 +685,22 @@ def main(): global_step = 0 min_val_loss = np.inf - imageToLatents = ImageToLatents(vae) - - checkpointer.save_samples( - "validation", - 0, - text_encoder, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) + if accelerator.is_main_process: + checkpointer.save_samples( + 0, + text_encoder, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Global steps") + local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), + disable=not accelerator.is_local_main_process) + local_progress_bar.set_description("Batch X out of Y") - local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) - local_progress_bar.set_description("Steps") + global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process) + global_progress_bar.set_description("Total progress") try: - for epoch in range(args.num_train_epochs): + for epoch in range(num_epochs): + local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}") local_progress_bar.reset() text_encoder.train() @@ -689,27 +708,30 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): - with accelerator.autocast(): - # Convert images to latent space - latents = imageToLatents(batch) + # Convert images to latent space + with torch.no_grad(): + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 - # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, - (bsz,), 
device=latents.device).long() + # Sample noise that we'll add to the latents + noise = torch.randn(latents.shape).to(latents.device) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with accelerator.autocast(): loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) @@ -727,32 +749,27 @@ def main(): optimizer.step() if not accelerator.optimizer_step_was_skipped: lr_scheduler.step() - optimizer.zero_grad() + optimizer.zero_grad(set_to_none=True) loss = loss.detach().item() train_loss += loss # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - progress_bar.update(1) local_progress_bar.update(1) + global_progress_bar.update(1) global_step += 1 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: - progress_bar.clear() local_progress_bar.clear() + global_progress_bar.clear() checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder) save_resume_file(basepath, args, { "global_step": global_step + global_step_offset, "resume_checkpoint": f"{basepath}/checkpoints/last.bin" }) - checkpointer.save_samples( - "training", - global_step + global_step_offset, - text_encoder, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} local_progress_bar.set_postfix(**logs) @@ -762,17 +779,21 @@ def main(): train_loss /= len(train_dataloader) + accelerator.wait_for_everyone() + text_encoder.eval() val_loss = 0.0 for step, batch in enumerate(val_dataloader): - with torch.no_grad(), accelerator.autocast(): - latents = imageToLatents(batch) + with torch.no_grad(): + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 noise = torch.randn(latents.shape).to(latents.device) bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, - (bsz,), device=latents.device).long() + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -782,14 +803,15 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + with accelerator.autocast(): + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = loss.detach().item() val_loss += loss if 
accelerator.sync_gradients: - progress_bar.update(1) local_progress_bar.update(1) + global_progress_bar.update(1) logs = {"mode": "validation", "loss": loss} local_progress_bar.set_postfix(**logs) @@ -798,21 +820,19 @@ def main(): accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) - progress_bar.clear() local_progress_bar.clear() + global_progress_bar.clear() if min_val_loss > val_loss: accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) min_val_loss = val_loss - checkpointer.save_samples( - "validation", - global_step + global_step_offset, - text_encoder, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) - - accelerator.wait_for_everyone() + if accelerator.is_main_process: + checkpointer.save_samples( + global_step + global_step_offset, + text_encoder, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: -- cgit v1.2.3-70-g09d2
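
Editor's note: below is a minimal, illustrative sketch (not part of the patch) of the euler_a sampling path that the Checkpointer classes in dreambooth.py and textual_inversion.py now use. The helper name, seed, device, and prompt handling are assumptions added for illustration; the scheduler parameters, pipeline constructor arguments, generator usage, and the 30-step default are taken directly from the hunks above and assume the repository's schedulers/ and pipelines/ packages are importable.

import torch
from transformers import CLIPFeatureExtractor

from schedulers.scheduling_euler_a import EulerAScheduler
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion


def generate_euler_a_samples(text_encoder, vae, unet, tokenizer, prompts,
                             seed=42, device="cuda"):
    # Scheduler parameters copied from the Checkpointer changes in this patch.
    scheduler = EulerAScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )

    # VlpnStableDiffusion replaces StableDiffusionPipeline + NoCheck in both scripts.
    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
    ).to(device)
    pipeline.enable_attention_slicing()

    # Seeded generator so sample grids stay comparable between checkpoints,
    # mirroring torch.Generator(device="cuda").manual_seed(self.seed + i) in the patch.
    generator = torch.Generator(device=device).manual_seed(seed)

    return pipeline(
        prompt=prompts,
        height=512,
        width=512,
        guidance_scale=7.5,
        eta=0.0,
        num_inference_steps=30,  # new default for --sample_steps in both scripts
        generator=generator,
        output_type="pil",
    )["sample"]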