From d673760fc671d665aadae3b032f8e99f21ab986d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 16 Feb 2023 09:16:05 +0100 Subject: Integrated WIP UniPC scheduler --- training/functional.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index e1035ce..b7ea90d 100644 --- a/training/functional.py +++ b/training/functional.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader from accelerate import Accelerator from transformers import CLIPTextModel -from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from tqdm.auto import tqdm from PIL import Image @@ -22,6 +22,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings from models.clip.util import get_extended_embeddings from models.clip.tokenizer import MultiCLIPTokenizer +from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler from training.util import AverageMeter @@ -79,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str): vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') - sample_scheduler = DPMSolverMultistepScheduler.from_pretrained( + sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') embeddings = patch_managed_embeddings(text_encoder) @@ -93,7 +94,7 @@ def save_samples( text_encoder: CLIPTextModel, tokenizer: MultiCLIPTokenizer, vae: AutoencoderKL, - sample_scheduler: DPMSolverMultistepScheduler, + sample_scheduler: UniPCMultistepScheduler, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], output_dir: Path, @@ -180,7 +181,7 @@ def generate_class_images( vae: AutoencoderKL, unet: UNet2DConditionModel, tokenizer: MultiCLIPTokenizer, - sample_scheduler: DPMSolverMultistepScheduler, + sample_scheduler: UniPCMultistepScheduler, train_dataset: VlpnDataset, sample_batch_size: int, sample_image_size: int, @@ -284,6 +285,7 @@ def loss_step( device=latents.device, generator=generator ) + bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint( @@ -351,6 +353,7 @@ def train_loop( train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], loss_step: LossCallable, + no_val: bool = False, sample_frequency: int = 10, checkpoint_frequency: int = 50, global_step_offset: int = 0, @@ -406,9 +409,15 @@ def train_loop( for epoch in range(num_epochs): if accelerator.is_main_process: if epoch % sample_frequency == 0: + local_progress_bar.clear() + global_progress_bar.clear() + on_sample(global_step + global_step_offset) if epoch % checkpoint_frequency == 0 and epoch != 0: + local_progress_bar.clear() + global_progress_bar.clear() + on_checkpoint(global_step + global_step_offset, "training") local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -463,7 +472,7 @@ def train_loop( on_after_epoch(lr_scheduler.get_last_lr()[0]) - if val_dataloader is not None: + if val_dataloader is not None and not no_val: model.eval() cur_loss_val = AverageMeter() @@ -498,11 +507,11 @@ def train_loop( accelerator.log(logs, step=global_step) - local_progress_bar.clear() - global_progress_bar.clear() - if accelerator.is_main_process: if avg_acc_val.avg.item() > best_acc_val: + local_progress_bar.clear() + global_progress_bar.clear() + accelerator.print( f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") on_checkpoint(global_step + global_step_offset, "milestone") @@ -513,6 +522,9 @@ def train_loop( else: if accelerator.is_main_process: if avg_acc.avg.item() > best_acc: + local_progress_bar.clear() + global_progress_bar.clear() + accelerator.print( f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") on_checkpoint(global_step + global_step_offset, "milestone") @@ -550,6 +562,7 @@ def train( optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, strategy: TrainingStrategy, + no_val: bool = False, num_train_epochs: int = 100, sample_frequency: int = 20, checkpoint_frequency: int = 50, @@ -604,6 +617,7 @@ def train( lr_scheduler=lr_scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, + no_val=no_val, loss_step=loss_step_, sample_frequency=sample_frequency, checkpoint_frequency=checkpoint_frequency, -- cgit v1.2.3-54-g00ecf