From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 13:49:35 +0100 Subject: Removed PromptProcessor, modularized training loop --- data/csv.py | 36 +-- models/clip/embeddings.py | 6 +- models/clip/prompt.py | 38 --- models/clip/util.py | 34 +++ .../stable_diffusion/vlpn_stable_diffusion.py | 20 +- train_dreambooth.py | 7 +- train_ti.py | 268 ++++----------------- training/common.py | 205 +++++++++++++++- training/util.py | 13 +- 9 files changed, 334 insertions(+), 293 deletions(-) delete mode 100644 models/clip/prompt.py create mode 100644 models/clip/util.py diff --git a/data/csv.py b/data/csv.py index f5fc8e6..a3fef30 100644 --- a/data/csv.py +++ b/data/csv.py @@ -9,9 +9,10 @@ from PIL import Image from torch.utils.data import IterableDataset, DataLoader, random_split from torchvision import transforms +from transformers import CLIPTokenizer from data.keywords import prompt_to_keywords, keywords_to_prompt -from models.clip.prompt import PromptProcessor +from models.clip.util import unify_input_ids image_cache: dict[str, Image.Image] = {} @@ -102,7 +103,7 @@ def generate_buckets( def collate_fn( num_class_images: int, weight_dtype: torch.dtype, - prompt_processor: PromptProcessor, + tokenizer: CLIPTokenizer, examples ): prompt_ids = [example["prompt_ids"] for example in examples] @@ -119,9 +120,9 @@ def collate_fn( pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) - prompts = prompt_processor.unify_input_ids(prompt_ids) - nprompts = prompt_processor.unify_input_ids(nprompt_ids) - inputs = prompt_processor.unify_input_ids(input_ids) + prompts = unify_input_ids(tokenizer, prompt_ids) + nprompts = unify_input_ids(tokenizer, nprompt_ids) + inputs = unify_input_ids(tokenizer, input_ids) batch = { "prompt_ids": prompts.input_ids, @@ -148,7 +149,7 @@ class VlpnDataModule(): self, batch_size: int, data_file: str, - prompt_processor: PromptProcessor, + tokenizer: CLIPTokenizer, class_subdir: str = "cls", num_class_images: int = 1, size: int = 768, @@ -179,7 +180,7 @@ class VlpnDataModule(): self.class_root.mkdir(parents=True, exist_ok=True) self.num_class_images = num_class_images - self.prompt_processor = prompt_processor + self.tokenizer = tokenizer self.size = size self.num_buckets = num_buckets self.bucket_step_size = bucket_step_size @@ -272,7 +273,7 @@ class VlpnDataModule(): self.data_val = self.pad_items(data_val) train_dataset = VlpnDataset( - self.data_train, self.prompt_processor, + self.data_train, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, generator=generator, @@ -281,7 +282,7 @@ class VlpnDataModule(): ) val_dataset = VlpnDataset( - self.data_val, self.prompt_processor, + self.data_val, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=True, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, repeat=self.valid_set_repeat, @@ -289,7 +290,7 @@ class VlpnDataModule(): size=self.size, interpolation=self.interpolation, ) - collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor) + collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer) self.train_dataloader = DataLoader( train_dataset, @@ -306,7 +307,7 @@ class VlpnDataset(IterableDataset): def __init__( self, items: list[VlpnDataItem], - 
prompt_processor: PromptProcessor, + tokenizer: CLIPTokenizer, num_buckets: int = 1, bucket_step_size: int = 64, bucket_max_pixels: Optional[int] = None, @@ -323,7 +324,7 @@ class VlpnDataset(IterableDataset): self.items = items * repeat self.batch_size = batch_size - self.prompt_processor = prompt_processor + self.tokenizer = tokenizer self.num_class_images = num_class_images self.size = size self.dropout = dropout @@ -344,6 +345,9 @@ class VlpnDataset(IterableDataset): self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() + def get_input_ids(self, text: str): + return self.tokenizer(text, padding="do_not_pad").input_ids + def __len__(self): return self.length_ @@ -404,16 +408,16 @@ class VlpnDataset(IterableDataset): example = {} - example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) - example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) + example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) + example["nprompt_ids"] = self.get_input_ids(item.nprompt) - example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( + example["instance_prompt_ids"] = self.get_input_ids( keywords_to_prompt(item.prompt, self.dropout, True) ) example["instance_images"] = image_transforms(get_image(item.instance_image_path)) if self.num_class_images != 0: - example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) + example["class_prompt_ids"] = self.get_input_ids(item.cprompt) example["class_images"] = image_transforms(get_image(item.class_image_path)) batch.append(example) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9a23a2a..761efbc 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -40,6 +40,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.position_embedding = embeddings.position_embedding self.initializer_factor = config.initializer_factor + self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item() + self.temp_token_embedding = nn.Embedding( self.token_embedding.num_embeddings, self.token_embedding.embedding_dim, @@ -99,7 +101,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return embeds - def normalize(self, target: float = 0.4, lambda_: float = 1.0): + def normalize(self, target: Optional[float] = None, lambda_: float = 1.0): + if target is None: + target = self.decay_target w = self.temp_token_embedding.weight pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) w[self.temp_token_ids] = F.normalize( diff --git a/models/clip/prompt.py b/models/clip/prompt.py deleted file mode 100644 index a7380be..0000000 --- a/models/clip/prompt.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Union, Optional - -import torch - -from transformers import CLIPTokenizer, CLIPTextModel - - -class PromptProcessor(): - def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel): - self.tokenizer = tokenizer - self.text_encoder = text_encoder - - def get_input_ids(self, prompt: Union[str, list[str]]): - return self.tokenizer( - prompt, - padding="do_not_pad", - ).input_ids - - def unify_input_ids(self, input_ids: list[list[int]]): - return self.tokenizer.pad( - {"input_ids": input_ids}, - padding=True, - pad_to_multiple_of=self.tokenizer.model_max_length, - return_tensors="pt" - ) - - def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None): - prompts = 
input_ids.shape[0] - - input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) - if position_ids is not None: - position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) - if attention_mask is not None: - attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) - - text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] - text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) - return text_embeddings diff --git a/models/clip/util.py b/models/clip/util.py new file mode 100644 index 0000000..8de8c19 --- /dev/null +++ b/models/clip/util.py @@ -0,0 +1,34 @@ +from typing import Optional + +import torch + +from transformers import CLIPTokenizer, CLIPTextModel + + +def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]): + return tokenizer.pad( + {"input_ids": input_ids}, + padding=True, + pad_to_multiple_of=tokenizer.model_max_length, + return_tensors="pt" + ) + + +def get_extended_embeddings( + text_encoder: CLIPTextModel, + input_ids: torch.LongTensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask=None +): + model_max_length = text_encoder.config.max_position_embeddings + prompts = input_ids.shape[0] + + input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) + if position_ids is not None: + position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) + if attention_mask is not None: + attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) + + text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] + text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) + return text_embeddings diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 6bc40e9..a5cfc60 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -22,7 +22,7 @@ from diffusers import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging, randn_tensor from transformers import CLIPTextModel, CLIPTokenizer -from models.clip.prompt import PromptProcessor +from models.clip.util import unify_input_ids, get_extended_embeddings logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -70,8 +70,6 @@ class VlpnStableDiffusion(DiffusionPipeline): new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - self.prompt_processor = PromptProcessor(tokenizer, text_encoder) - self.register_modules( vae=vae, text_encoder=text_encoder, @@ -213,16 +211,22 @@ class VlpnStableDiffusion(DiffusionPipeline): do_classifier_free_guidance: bool, device ): - text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt + if isinstance(prompt[0], str): + text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids + else: + text_input_ids = prompt + text_input_ids *= num_images_per_prompt if do_classifier_free_guidance: - unconditional_input_ids = self.prompt_processor.get_input_ids( - negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt + if isinstance(prompt[0], str): + unconditional_input_ids = self.tokenizer(negative_prompt, 
padding="do_not_pad").input_ids + else: + unconditional_input_ids = negative_prompt unconditional_input_ids *= num_images_per_prompt text_input_ids = unconditional_input_ids + text_input_ids - text_inputs = self.prompt_processor.unify_input_ids(text_input_ids) + text_inputs = unify_input_ids(self.tokenizer, text_input_ids) text_input_ids = text_inputs.input_ids if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: @@ -230,7 +234,7 @@ class VlpnStableDiffusion(DiffusionPipeline): else: attention_mask = None - text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask) + text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) return text_embeddings diff --git a/train_dreambooth.py b/train_dreambooth.py index 0fe590f..fbbe6c2 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -27,7 +27,6 @@ from training.common import loss_step, generate_class_images, get_scheduler from training.lr import LRFinder from training.util import AverageMeter, CheckpointerBase, save_args from models.clip.embeddings import patch_managed_embeddings -from models.clip.prompt import PromptProcessor from models.clip.tokenizer import MultiCLIPTokenizer logger = get_logger(__name__) @@ -690,8 +689,6 @@ def main(): text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) - prompt_processor = PromptProcessor(tokenizer, text_encoder) - if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * @@ -751,7 +748,7 @@ def main(): datamodule = VlpnDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, - prompt_processor=prompt_processor, + tokenizer=tokenizer, class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, @@ -876,7 +873,7 @@ def main(): vae, noise_scheduler, unet, - prompt_processor, + text_encoder, args.num_class_images, args.prior_loss_weight, args.seed, diff --git a/train_ti.py b/train_ti.py index e18ee38..8c86586 100644 --- a/train_ti.py +++ b/train_ti.py @@ -21,11 +21,10 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import VlpnDataModule, VlpnDataItem -from training.common import loss_step, generate_class_images, get_scheduler +from training.common import loss_step, train_loop, generate_class_images, get_scheduler from training.lr import LRFinder from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args from models.clip.embeddings import patch_managed_embeddings -from models.clip.prompt import PromptProcessor from models.clip.tokenizer import MultiCLIPTokenizer logger = get_logger(__name__) @@ -197,12 +196,6 @@ def parse_args(): type=int, default=100 ) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) parser.add_argument( "--gradient_accumulation_steps", type=int, @@ -409,7 +402,7 @@ def parse_args(): ) parser.add_argument( "--decay_target", - default=0.4, + default=None, type=float, help="Embedding decay target." 
) @@ -668,8 +661,6 @@ def main(): text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) - prompt_processor = PromptProcessor(tokenizer, text_encoder) - if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * @@ -722,7 +713,7 @@ def main(): datamodule = VlpnDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, - prompt_processor=prompt_processor, + tokenizer=tokenizer, class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, @@ -759,13 +750,7 @@ def main(): args.sample_steps ) - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) if args.find_lr: lr_scheduler = None @@ -781,7 +766,7 @@ def main(): annealing_exp=args.lr_annealing_exp, cycles=args.lr_cycles, warmup_epochs=args.lr_warmup_epochs, - max_train_steps=args.max_train_steps, + num_train_epochs=args.num_train_epochs, num_update_steps_per_epoch=num_update_steps_per_epoch, gradient_accumulation_steps=args.gradient_accumulation_steps ) @@ -805,15 +790,6 @@ def main(): else: unet.eval() - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - - num_val_steps_per_epoch = len(val_dataloader) - num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - val_steps = num_val_steps_per_epoch * num_epochs - @contextmanager def on_train(): try: @@ -842,19 +818,44 @@ def main(): min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) ) + if args.use_ema: + ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + + def on_log(): + if args.use_ema: + return {"ema_decay": ema_embeddings.decay} + return {} + loop = partial( loss_step, vae, noise_scheduler, unet, - prompt_processor, + text_encoder, args.num_class_images, args.prior_loss_weight, args.seed, ) - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. + checkpointer = Checkpointer( + weight_dtype=weight_dtype, + datamodule=datamodule, + accelerator=accelerator, + vae=vae, + unet=unet, + tokenizer=tokenizer, + text_encoder=text_encoder, + ema_embeddings=ema_embeddings, + scheduler=checkpoint_scheduler, + placeholder_token=args.placeholder_token, + new_ids=new_ids, + output_dir=basepath, + sample_image_size=args.sample_image_size, + sample_batch_size=args.sample_batch_size, + sample_batches=args.sample_batches, + seed=args.seed + ) + if accelerator.is_main_process: config = vars(args).copy() config["initializer_token"] = " ".join(config["initializer_token"]) @@ -882,190 +883,27 @@ def main(): plt.savefig(basepath.joinpath("lr.png"), dpi=300) plt.close() - - quit() - - # Train! 
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num Epochs = {num_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. - - global_step = 0 - - avg_loss = AverageMeter() - avg_acc = AverageMeter() - - avg_loss_val = AverageMeter() - avg_acc_val = AverageMeter() - - max_acc_val = 0.0 - - checkpointer = Checkpointer( - weight_dtype=weight_dtype, - datamodule=datamodule, - accelerator=accelerator, - vae=vae, - unet=unet, - tokenizer=tokenizer, - text_encoder=text_encoder, - ema_embeddings=ema_embeddings, - scheduler=checkpoint_scheduler, - placeholder_token=args.placeholder_token, - new_ids=new_ids, - output_dir=basepath, - sample_image_size=args.sample_image_size, - sample_batch_size=args.sample_batch_size, - sample_batches=args.sample_batches, - seed=args.seed - ) - - local_progress_bar = tqdm( - range(num_update_steps_per_epoch + num_val_steps_per_epoch), - disable=not accelerator.is_local_main_process, - dynamic_ncols=True - ) - local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") - - global_progress_bar = tqdm( - range(args.max_train_steps + val_steps), - disable=not accelerator.is_local_main_process, - dynamic_ncols=True - ) - global_progress_bar.set_description("Total progress") - - try: - for epoch in range(num_epochs): - if accelerator.is_main_process: - if epoch % args.sample_frequency == 0: - checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) - - local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") - local_progress_bar.reset() - - text_encoder.train() - - with on_train(): - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(text_encoder): - loss, acc, bsz = loop(step, batch) - - accelerator.backward(loss) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - avg_loss.update(loss.detach_(), bsz) - avg_acc.update(acc.detach_(), bsz) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - on_after_optimize(lr_scheduler.get_last_lr()[0]) - - if args.use_ema: - ema_embeddings.step( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - local_progress_bar.update(1) - global_progress_bar.update(1) - - global_step += 1 - - logs = { - "train/loss": avg_loss.avg.item(), - "train/acc": avg_acc.avg.item(), - "train/cur_loss": loss.item(), - "train/cur_acc": acc.item(), - "lr": lr_scheduler.get_last_lr()[0], - } - if args.use_ema: - logs["ema_decay"] = ema_embeddings.decay - - accelerator.log(logs, step=global_step) - - local_progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - accelerator.wait_for_everyone() - - text_encoder.eval() - - cur_loss_val = AverageMeter() - cur_acc_val = AverageMeter() - - with torch.inference_mode(): - with on_eval(): - for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loop(step, batch, True) - - loss = loss.detach_() - acc = acc.detach_() - - cur_loss_val.update(loss, bsz) - cur_acc_val.update(acc, bsz) - - avg_loss_val.update(loss, bsz) - 
avg_acc_val.update(acc, bsz) - - local_progress_bar.update(1) - global_progress_bar.update(1) - - logs = { - "val/loss": avg_loss_val.avg.item(), - "val/acc": avg_acc_val.avg.item(), - "val/cur_loss": loss.item(), - "val/cur_acc": acc.item(), - } - local_progress_bar.set_postfix(**logs) - - logs["val/cur_loss"] = cur_loss_val.avg.item() - logs["val/cur_acc"] = cur_acc_val.avg.item() - - accelerator.log(logs, step=global_step) - - local_progress_bar.clear() - global_progress_bar.clear() - - if accelerator.is_main_process: - if avg_acc_val.avg.item() > max_acc_val: - accelerator.print( - f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") - checkpointer.checkpoint(global_step + global_step_offset, "milestone") - max_acc_val = avg_acc_val.avg.item() - - if (epoch + 1) % args.checkpoint_frequency == 0: - checkpointer.checkpoint(global_step + global_step_offset, "training") - save_args(basepath, args, { - "global_step": global_step + global_step_offset - }) - - # Create the pipeline using using the trained modules and save it. - if accelerator.is_main_process: - print("Finished! Saving final checkpoint and resume state.") - checkpointer.checkpoint(global_step + global_step_offset, "end") - checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) - save_args(basepath, args, { - "global_step": global_step + global_step_offset - }) - accelerator.end_training() - - except KeyboardInterrupt: - if accelerator.is_main_process: - print("Interrupted, saving checkpoint and resume state...") - checkpointer.checkpoint(global_step + global_step_offset, "end") - save_args(basepath, args, { - "global_step": global_step + global_step_offset - }) - accelerator.end_training() - quit() + else: + train_loop( + accelerator=accelerator, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + model=text_encoder, + checkpointer=checkpointer, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + loss_step=loop, + sample_frequency=args.sample_frequency, + sample_steps=args.sample_steps, + checkpoint_frequency=args.checkpoint_frequency, + global_step_offset=global_step_offset, + gradient_accumulation_steps=args.gradient_accumulation_steps, + num_epochs=args.num_train_epochs, + on_log=on_log, + on_train=on_train, + on_after_optimize=on_after_optimize, + on_eval=on_eval + ) if __name__ == "__main__": diff --git a/training/common.py b/training/common.py index 90cf910..842ac07 100644 --- a/training/common.py +++ b/training/common.py @@ -1,14 +1,30 @@ import math +from contextlib import _GeneratorContextManager, nullcontext +from typing import Callable, Any, Tuple, Union import torch import torch.nn.functional as F +from torch.utils.data import DataLoader +from accelerate import Accelerator +from transformers import CLIPTokenizer, CLIPTextModel from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from tqdm.auto import tqdm +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from models.clip.util import get_extended_embeddings from training.optimization import get_one_cycle_schedule +from training.util import AverageMeter, CheckpointerBase + + +def noop(*args, **kwards): + pass + + +def noop_on_log(): + return {} def get_scheduler( @@ -22,10 +38,11 @@ def get_scheduler( 
cycles: int, warmup_epochs: int, optimizer: torch.optim.Optimizer, - max_train_steps: int, + num_train_epochs: int, num_update_steps_per_epoch: int, gradient_accumulation_steps: int, ): + num_train_steps = num_train_epochs * num_update_steps_per_epoch warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps if id == "one_cycle": @@ -33,7 +50,7 @@ def get_scheduler( lr_scheduler = get_one_cycle_schedule( optimizer=optimizer, - num_training_steps=max_train_steps * gradient_accumulation_steps, + num_training_steps=num_train_steps * gradient_accumulation_steps, warmup=warmup_func, annealing=annealing_func, warmup_exp=warmup_exp, @@ -42,12 +59,12 @@ def get_scheduler( ) elif id == "cosine_with_restarts": cycles = cycles if cycles is not None else math.ceil( - math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch))) + math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch))) lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=warmup_steps, - num_training_steps=max_train_steps * gradient_accumulation_steps, + num_training_steps=num_train_steps * gradient_accumulation_steps, num_cycles=cycles, ) else: @@ -55,7 +72,7 @@ def get_scheduler( id, optimizer=optimizer, num_warmup_steps=warmup_steps, - num_training_steps=max_train_steps * gradient_accumulation_steps, + num_training_steps=num_train_steps * gradient_accumulation_steps, ) return lr_scheduler @@ -117,12 +134,12 @@ def loss_step( vae: AutoencoderKL, noise_scheduler: DDPMScheduler, unet: UNet2DConditionModel, - prompt_processor, + text_encoder: CLIPTextModel, num_class_images: int, prior_loss_weight: float, seed: int, step: int, - batch, + batch: dict[str, Any], eval: bool = False ): # Convert images to latent space @@ -149,7 +166,8 @@ def loss_step( noisy_latents = noisy_latents.to(dtype=unet.dtype) # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings( + encoder_hidden_states = get_extended_embeddings( + text_encoder, batch["input_ids"], batch["attention_mask"] ) @@ -185,3 +203,172 @@ def loss_step( acc = (model_pred == target).float().mean() return loss, acc, bsz + + +def train_loop( + accelerator: Accelerator, + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + model: torch.nn.Module, + checkpointer: CheckpointerBase, + train_dataloader: DataLoader, + val_dataloader: DataLoader, + loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], + sample_frequency: int = 10, + sample_steps: int = 20, + checkpoint_frequency: int = 50, + global_step_offset: int = 0, + gradient_accumulation_steps: int = 1, + num_epochs: int = 100, + on_log: Callable[[], dict[str, Any]] = noop_on_log, + on_train: Callable[[], _GeneratorContextManager] = nullcontext, + on_before_optimize: Callable[[], None] = noop, + on_after_optimize: Callable[[float], None] = noop, + on_eval: Callable[[], _GeneratorContextManager] = nullcontext +): + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + num_train_steps = num_epochs * num_update_steps_per_epoch + + num_val_steps_per_epoch = len(val_dataloader) + num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch) + num_val_steps = num_val_steps_per_epoch * num_epochs + + global_step = 0 + + avg_loss = AverageMeter() + avg_acc = AverageMeter() + + avg_loss_val = AverageMeter() + avg_acc_val = AverageMeter() + + max_acc_val 
= 0.0 + + local_progress_bar = tqdm( + range(num_update_steps_per_epoch + num_val_steps_per_epoch), + disable=not accelerator.is_local_main_process, + dynamic_ncols=True + ) + local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") + + global_progress_bar = tqdm( + range(num_train_steps + num_val_steps), + disable=not accelerator.is_local_main_process, + dynamic_ncols=True + ) + global_progress_bar.set_description("Total progress") + + try: + for epoch in range(num_epochs): + if accelerator.is_main_process: + if epoch % sample_frequency == 0: + checkpointer.save_samples(global_step + global_step_offset, sample_steps) + + if epoch % checkpoint_frequency == 0 and epoch != 0: + checkpointer.checkpoint(global_step + global_step_offset, "training") + + local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") + local_progress_bar.reset() + + model.train() + + with on_train(): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(model): + loss, acc, bsz = loss_step(step, batch) + + accelerator.backward(loss) + + on_before_optimize() + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + avg_loss.update(loss.detach_(), bsz) + avg_acc.update(acc.detach_(), bsz) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + on_after_optimize(lr_scheduler.get_last_lr()[0]) + + local_progress_bar.update(1) + global_progress_bar.update(1) + + global_step += 1 + + logs = { + "train/loss": avg_loss.avg.item(), + "train/acc": avg_acc.avg.item(), + "train/cur_loss": loss.item(), + "train/cur_acc": acc.item(), + "lr": lr_scheduler.get_last_lr()[0], + } + logs.update(on_log()) + + accelerator.log(logs, step=global_step) + + local_progress_bar.set_postfix(**logs) + + if global_step >= num_train_steps: + break + + accelerator.wait_for_everyone() + + model.eval() + + cur_loss_val = AverageMeter() + cur_acc_val = AverageMeter() + + with torch.inference_mode(): + with on_eval(): + for step, batch in enumerate(val_dataloader): + loss, acc, bsz = loss_step(step, batch, True) + + loss = loss.detach_() + acc = acc.detach_() + + cur_loss_val.update(loss, bsz) + cur_acc_val.update(acc, bsz) + + avg_loss_val.update(loss, bsz) + avg_acc_val.update(acc, bsz) + + local_progress_bar.update(1) + global_progress_bar.update(1) + + logs = { + "val/loss": avg_loss_val.avg.item(), + "val/acc": avg_acc_val.avg.item(), + "val/cur_loss": loss.item(), + "val/cur_acc": acc.item(), + } + local_progress_bar.set_postfix(**logs) + + logs["val/cur_loss"] = cur_loss_val.avg.item() + logs["val/cur_acc"] = cur_acc_val.avg.item() + + accelerator.log(logs, step=global_step) + + local_progress_bar.clear() + global_progress_bar.clear() + + if accelerator.is_main_process: + if avg_acc_val.avg.item() > max_acc_val: + accelerator.print( + f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") + checkpointer.checkpoint(global_step + global_step_offset, "milestone") + max_acc_val = avg_acc_val.avg.item() + + # Create the pipeline using using the trained modules and save it. 
+ if accelerator.is_main_process: + print("Finished!") + checkpointer.checkpoint(global_step + global_step_offset, "end") + checkpointer.save_samples(global_step + global_step_offset, sample_steps) + accelerator.end_training() + + except KeyboardInterrupt: + if accelerator.is_main_process: + print("Interrupted") + checkpointer.checkpoint(global_step + global_step_offset, "end") + accelerator.end_training() + quit() diff --git a/training/util.py b/training/util.py index 60d64f0..0ec2032 100644 --- a/training/util.py +++ b/training/util.py @@ -55,8 +55,19 @@ class CheckpointerBase: self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size + @torch.no_grad() + def checkpoint(self, step: int, postfix: str): + pass + @torch.inference_mode() - def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): + def save_samples( + self, + pipeline, + step: int, + num_inference_steps: int, + guidance_scale: float = 7.5, + eta: float = 0.0 + ): samples_path = Path(self.output_dir).joinpath("samples") train_data = self.datamodule.train_dataloader -- cgit v1.2.3-54-g00ecf
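The new models/clip/util.py turns the former PromptProcessor methods into free functions that take the tokenizer and text encoder explicitly. A minimal sketch of how they compose, assuming the patched repository is on the import path; the CLIP checkpoint name and the prompts are placeholders, not part of the commit:

# Usage sketch for the helpers added in models/clip/util.py.
# Checkpoint name and prompts are placeholders.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

from models.clip.util import unify_input_ids, get_extended_embeddings

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a photo of a cat", "a painting of a fox in the snow, " * 10]

# Tokenize without padding, as VlpnDataset.get_input_ids now does inline.
input_ids = tokenizer(prompts, padding="do_not_pad").input_ids

# Pad every sequence up to a multiple of model_max_length (77 for CLIP), so an
# over-long prompt becomes several full-length chunks instead of being truncated.
inputs = unify_input_ids(tokenizer, input_ids)

# Fold the 77-token chunks into the batch dimension, run the text encoder, and
# unfold again: one (possibly extended) embedding sequence per prompt.
with torch.no_grad():
    embeddings = get_extended_embeddings(
        text_encoder, inputs.input_ids, attention_mask=inputs.attention_mask
    )

print(embeddings.shape)  # (2, n * 77, hidden_size), n = number of 77-token chunks

Because unify_input_ids pads to a multiple of tokenizer.model_max_length and get_extended_embeddings chunks on that same boundary, prompts longer than the 77-token CLIP context are encoded chunk-wise; the pipeline's _encode_prompt and loss_step both rely on this.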
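train_loop leaves model-specific behaviour to optional callbacks: on_train and on_eval are callables that return context managers (nullcontext by default), on_log returns extra key/value pairs merged into the logged metrics, and on_before_optimize / on_after_optimize default to no-ops. A sketch of that contract follows; only the signatures come from the patch, the hook bodies are illustrative:

# Hook-contract sketch for training.common.train_loop; bodies are illustrative.
from contextlib import contextmanager

@contextmanager
def on_train():
    # wraps each training epoch, e.g. to switch the model into a training-only mode
    try:
        yield
    finally:
        pass

@contextmanager
def on_eval():
    # wraps the validation pass, e.g. to restore deterministic behaviour
    try:
        yield
    finally:
        pass

def on_log() -> dict:
    # extra values merged into the metrics sent to accelerator.log()
    return {}

def on_before_optimize():
    # runs after accelerator.backward(), before optimizer.step()
    pass

def on_after_optimize(lr: float):
    # runs once per synchronized optimizer step, with the current learning rate
    pass

In train_ti.py these hooks are wired exactly this way: on_after_optimize performs the embedding re-normalization and the EMA update, while on_log adds the current EMA decay to the logged metrics.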