From 0f493e1ac8406de061861ed390f283e821180e79 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 11:26:31 +0200 Subject: Use euler_a for samples in learning scripts; backported improvement from Dreambooth to Textual Inversion --- textual_inversion.py | 308 +++++++++++++++++++++++++++------------------------ 1 file changed, 164 insertions(+), 144 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 399d876..7a7d7fc 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -3,6 +3,8 @@ import itertools import math import os import datetime +import logging +from pathlib import Path import numpy as np import torch @@ -13,12 +15,13 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel +from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler -from pipelines.stable_diffusion.no_check import NoCheck from PIL import Image from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from slugify import slugify +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion import json import os @@ -44,10 +47,10 @@ def parse_args(): help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( - "--train_data_dir", + "--train_data_file", type=str, default=None, - help="A folder containing the training data." + help="A CSV file containing the training data." ) parser.add_argument( "--placeholder_token", @@ -145,6 +148,11 @@ def parse_args(): default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument( "--adam_beta1", type=float, @@ -225,7 +233,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=50, + default=30, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -261,8 +269,8 @@ def parse_args(): if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank - if args.train_data_dir is None: - raise ValueError("You must specify --train_data_dir") + if args.train_data_file is None: + raise ValueError("You must specify --train_data_file") if args.pretrained_model_name_or_path is None: raise ValueError("You must specify --pretrained_model_name_or_path") @@ -333,53 +341,51 @@ class Checkpointer: @torch.no_grad() def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): print("Saving checkpoint for step %d..." 
% step) - with self.accelerator.autocast(): - if path is None: - checkpoints_path = f"{self.output_dir}/checkpoints" - os.makedirs(checkpoints_path, exist_ok=True) - - unwrapped = self.accelerator.unwrap_model(text_encoder) - - # Save a checkpoint - learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] - learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} - - filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) - if path is not None: - torch.save(learned_embeds_dict, path) - else: - torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") - torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") - del unwrapped - del learned_embeds + + if path is None: + checkpoints_path = f"{self.output_dir}/checkpoints" + os.makedirs(checkpoints_path, exist_ok=True) + + unwrapped = self.accelerator.unwrap_model(text_encoder) + + # Save a checkpoint + learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] + learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} + + filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) + if path is not None: + torch.save(learned_embeds_dict, path) + else: + torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") + torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") + + del unwrapped + del learned_embeds @torch.no_grad() - def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): - samples_path = f"{self.output_dir}/samples/{mode}" - os.makedirs(samples_path, exist_ok=True) - checker = NoCheck() + def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): + samples_path = Path(self.output_dir).joinpath("samples") unwrapped = self.accelerator.unwrap_model(text_encoder) + scheduler = EulerAScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + # Save a sample image - pipeline = StableDiffusionPipeline( + pipeline = VlpnStableDiffusion( text_encoder=unwrapped, vae=self.vae, unet=self.unet, tokenizer=self.tokenizer, - scheduler=LMSDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ), - safety_checker=NoCheck(), + scheduler=scheduler, feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), ).to(self.accelerator.device) pipeline.enable_attention_slicing() - data = { - "training": self.datamodule.train_dataloader(), - "validation": self.datamodule.val_dataloader(), - }[mode] + train_data = self.datamodule.train_dataloader() + val_data = self.datamodule.val_dataloader() - if mode == "validation" and self.stable_sample_batches > 0 and step > 0: + if self.stable_sample_batches > 0: stable_latents = torch.randn( (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), device=pipeline.device, @@ -387,14 +393,17 @@ class Checkpointer: ) all_samples = [] - filename = f"stable_step_%d.png" % (step) + file_path = samples_path.joinpath("stable", f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) - data_enum = enumerate(data) + data_enum = enumerate(val_data) # Generate and save stable samples for i in range(0, self.stable_sample_batches): prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( - batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size] + batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] + + generator 
= torch.Generator(device="cuda").manual_seed(self.seed + i) with self.accelerator.autocast(): samples = pipeline( @@ -405,67 +414,64 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, + generator=generator, output_type='pil' )["sample"] all_samples += samples + + del generator del samples image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) - image_grid.save(f"{samples_path}/{filename}") + image_grid.save(file_path) del all_samples del image_grid del stable_latents - all_samples = [] - filename = f"step_%d.png" % (step) + for data, pool in [(val_data, "val"), (train_data, "train")]: + all_samples = [] + file_path = samples_path.joinpath(pool, f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) - data_enum = enumerate(data) + data_enum = enumerate(data) - # Generate and save random samples - for i in range(0, self.random_sample_batches): - prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( - batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size] + for i in range(0, self.random_sample_batches): + prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( + batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] - with self.accelerator.autocast(): - samples = pipeline( - prompt=prompt, - height=self.sample_image_size, - width=self.sample_image_size, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - output_type='pil' - )["sample"] + generator = torch.Generator(device="cuda").manual_seed(self.seed + i) - all_samples += samples - del samples + with self.accelerator.autocast(): + samples = pipeline( + prompt=prompt, + height=self.sample_image_size, + width=self.sample_image_size, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + generator=generator, + output_type='pil' + )["sample"] - image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) - image_grid.save(f"{samples_path}/{filename}") + all_samples += samples - del all_samples - del image_grid + del generator + del samples - del checker - del unwrapped - del pipeline - torch.cuda.empty_cache() + image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) + image_grid.save(file_path) + del all_samples + del image_grid -class ImageToLatents(): - def __init__(self, vae): - self.vae = vae - self.encoded_pixel_values_cache = {} + del unwrapped + del scheduler + del pipeline - @torch.no_grad() - def __call__(self, batch): - key = "|".join(batch["key"]) - if self.encoded_pixel_values_cache.get(key, None) is None: - self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist - latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215 - return latents + if torch.cuda.is_available(): + torch.cuda.empty_cache() def main(): @@ -473,17 +479,17 @@ def main(): global_step_offset = 0 if args.resume_from is not None: - basepath = f"{args.resume_from}" + basepath = Path(args.resume_from) print("Resuming state from %s" % args.resume_from) - with open(f"{basepath}/resume.json", 'r') as f: + with open(basepath.joinpath("resume.json"), 'r') as f: state = json.load(f) global_step_offset = state["args"].get("global_step", 0) print("We've trained %d steps so far" % global_step_offset) else: now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = 
f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}" - os.makedirs(basepath, exist_ok=True) + basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now) + basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, @@ -492,6 +498,8 @@ def main(): mixed_precision=args.mixed_precision ) + logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) + # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) @@ -570,8 +578,19 @@ def main(): args.train_batch_size * accelerator.num_processes ) + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + # Initialize the optimizer - optimizer = torch.optim.AdamW( + optimizer = optimizer_class( text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), @@ -585,7 +604,7 @@ def main(): ) datamodule = CSVDataModule( - data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer, + data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats, center_crop=args.center_crop) @@ -608,13 +627,12 @@ def main(): sample_batch_size=args.sample_batch_size, random_sample_batches=args.random_sample_batches, stable_sample_batches=args.stable_sample_batches, - seed=args.seed + seed=args.seed or torch.random.seed() ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -643,9 +661,10 @@ def main(): (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil( - args.max_train_steps / num_update_steps_per_epoch) + + num_val_steps_per_epoch = len(val_dataloader) + num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + val_steps = num_val_steps_per_epoch * num_epochs # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. @@ -656,7 +675,7 @@ def main(): total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") - logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Num Epochs = {num_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") @@ -666,22 +685,22 @@ def main(): global_step = 0 min_val_loss = np.inf - imageToLatents = ImageToLatents(vae) - - checkpointer.save_samples( - "validation", - 0, - text_encoder, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) + if accelerator.is_main_process: + checkpointer.save_samples( + 0, + text_encoder, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Global steps") + local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), + disable=not accelerator.is_local_main_process) + local_progress_bar.set_description("Batch X out of Y") - local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) - local_progress_bar.set_description("Steps") + global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process) + global_progress_bar.set_description("Total progress") try: - for epoch in range(args.num_train_epochs): + for epoch in range(num_epochs): + local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}") local_progress_bar.reset() text_encoder.train() @@ -689,27 +708,30 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): - with accelerator.autocast(): - # Convert images to latent space - latents = imageToLatents(batch) + # Convert images to latent space + with torch.no_grad(): + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 - # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, - (bsz,), device=latents.device).long() + # Sample noise that we'll add to the latents + noise = torch.randn(latents.shape).to(latents.device) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with accelerator.autocast(): loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) @@ -727,32 +749,27 @@ def main(): optimizer.step() if not accelerator.optimizer_step_was_skipped: lr_scheduler.step() - optimizer.zero_grad() + optimizer.zero_grad(set_to_none=True) loss = loss.detach().item() 
train_loss += loss # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - progress_bar.update(1) local_progress_bar.update(1) + global_progress_bar.update(1) global_step += 1 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: - progress_bar.clear() local_progress_bar.clear() + global_progress_bar.clear() checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder) save_resume_file(basepath, args, { "global_step": global_step + global_step_offset, "resume_checkpoint": f"{basepath}/checkpoints/last.bin" }) - checkpointer.save_samples( - "training", - global_step + global_step_offset, - text_encoder, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} local_progress_bar.set_postfix(**logs) @@ -762,17 +779,21 @@ def main(): train_loss /= len(train_dataloader) + accelerator.wait_for_everyone() + text_encoder.eval() val_loss = 0.0 for step, batch in enumerate(val_dataloader): - with torch.no_grad(), accelerator.autocast(): - latents = imageToLatents(batch) + with torch.no_grad(): + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 noise = torch.randn(latents.shape).to(latents.device) bsz = latents.shape[0] - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, - (bsz,), device=latents.device).long() + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -782,14 +803,15 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + with accelerator.autocast(): + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = loss.detach().item() val_loss += loss if accelerator.sync_gradients: - progress_bar.update(1) local_progress_bar.update(1) + global_progress_bar.update(1) logs = {"mode": "validation", "loss": loss} local_progress_bar.set_postfix(**logs) @@ -798,21 +820,19 @@ def main(): accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) - progress_bar.clear() local_progress_bar.clear() + global_progress_bar.clear() if min_val_loss > val_loss: accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) min_val_loss = val_loss - checkpointer.save_samples( - "validation", - global_step + global_step_offset, - text_encoder, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) - - accelerator.wait_for_everyone() + if accelerator.is_main_process: + checkpointer.save_samples( + global_step + global_step_offset, + text_encoder, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: -- cgit v1.2.3-54-g00ecf
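
Note on using the saved embeddings: Checkpointer.checkpoint() above writes each
checkpoint as a dict of {placeholder_token: embedding tensor} via torch.save. A
minimal sketch of how such a file could be loaded back into a tokenizer/text
encoder for inference follows; the function name and the hard failure on an
already-registered token are illustrative choices, not part of this patch.

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    def load_learned_embedding(checkpoint_path, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
        # the file holds {placeholder_token: embedding}, as saved by Checkpointer.checkpoint()
        learned_embeds = torch.load(checkpoint_path, map_location="cpu")
        token, embedding = next(iter(learned_embeds.items()))

        # register the placeholder token and grow the embedding matrix to fit it
        if tokenizer.add_tokens(token) == 0:
            raise ValueError(f"Token {token} is already present in the tokenizer")
        text_encoder.resize_token_embeddings(len(tokenizer))

        # copy the learned vector into the new row of the embedding matrix
        token_id = tokenizer.convert_tokens_to_ids(token)
        text_encoder.get_input_embeddings().weight.data[token_id] = embedding

        return token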
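
Note on the scheduler choice: the patch wires a vendored EulerAScheduler into the
custom VlpnStableDiffusion pipeline, apparently because diffusers did not yet ship
an Euler ancestral scheduler when this was written. As a rough sketch only, a newer
diffusers release that includes EulerAncestralDiscreteScheduler lets the same
sampler be swapped into a stock pipeline; the model id and prompt below are
placeholders, and the exact diffusers version required is an assumption.

    import torch
    from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler

    # load a Stable Diffusion checkpoint and replace its scheduler with Euler ancestral
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    # 30 steps matches the new --sample_steps default introduced in this patch
    image = pipe("a photo of a <placeholder>", num_inference_steps=30, guidance_scale=7.5).images[0]
    image.save("sample.png")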