From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 13:28:49 +0200
Subject: Update

---
 training/functional.py | 221 +++++++++++++++++++++++++++++++++----------------
 1 file changed, 149 insertions(+), 72 deletions(-)

(limited to 'training/functional.py')

diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..f68faf9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -14,7 +14,13 @@ import numpy as np
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    DDPMScheduler,
+    UniPCMultistepScheduler,
+    SchedulerMixin,
+)
 
 from tqdm.auto import tqdm
 
@@ -33,11 +39,12 @@ from util.noise import perlin_noise
 def const(result=None):
     def fn(*args, **kwargs):
         return result
+
     return fn
 
 
 @dataclass
-class TrainingCallbacks():
+class TrainingCallbacks:
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], Any] = const()
@@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol):
         train_dataloader: DataLoader,
         val_dataloader: Optional[DataLoader],
         lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-        **kwargs
-    ) -> Tuple: ...
+        **kwargs,
+    ) -> Tuple:
+        ...
 
 
 @dataclass
-class TrainingStrategy():
+class TrainingStrategy:
     callbacks: Callable[..., TrainingCallbacks]
     prepare: TrainingStrategyPrepareCallable
 
 
 def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype)
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype)
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype)
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    tokenizer = MultiCLIPTokenizer.from_pretrained(
+        pretrained_model_name_or_path, subfolder="tokenizer"
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype
+    )
+    vae = AutoencoderKL.from_pretrained(
+        pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype
+    )
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
 
@@ -113,7 +133,9 @@ def save_samples(
 
     generator = torch.Generator(device=accelerator.device).manual_seed(seed)
 
-    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]
+    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [
+        ("train", train_dataloader, None)
+    ]
 
     if val_dataloader is not None:
datasets.append(("stable", val_dataloader, generator)) @@ -124,17 +146,11 @@ def save_samples( file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" file_path.parent.mkdir(parents=True, exist_ok=True) - batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) - prompt_ids = [ - prompt - for batch in batches - for prompt in batch["prompt_ids"] - ] - nprompt_ids = [ - prompt - for batch in batches - for prompt in batch["nprompt_ids"] - ] + batches = list( + itertools.islice(itertools.cycle(data), batch_size * num_batches) + ) + prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]] + nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]] with torch.inference_mode(): for i in range(num_batches): @@ -165,7 +181,9 @@ def save_samples( pass image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) - image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] + image_grid = pipeline.numpy_to_pil( + image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() + )[0] image_grid.save(file_path, quality=85) del generator, pipeline @@ -184,15 +202,17 @@ def generate_class_images( train_dataset: VlpnDataset, sample_batch_size: int, sample_image_size: int, - sample_steps: int + sample_steps: int, ): - missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] + missing_data = [ + item for item in train_dataset.items if not item.class_image_path.exists() + ] if len(missing_data) == 0: return batched_data = [ - missing_data[i:i+sample_batch_size] + missing_data[i : i + sample_batch_size] for i in range(0, len(missing_data), sample_batch_size) ] @@ -216,7 +236,7 @@ def generate_class_images( negative_prompt=nprompt, height=sample_image_size, width=sample_image_size, - num_inference_steps=sample_steps + num_inference_steps=sample_steps, ).images for i, image in enumerate(images): @@ -245,8 +265,12 @@ def add_placeholder_tokens( embeddings.resize(len(tokenizer)) - for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): - embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise) + for placeholder_token_id, initializer_token_id in zip( + placeholder_token_ids, initializer_token_ids + ): + embeddings.add_embed( + placeholder_token_id, initializer_token_id, initializer_noise + ) return placeholder_token_ids, initializer_token_ids @@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler): # Expand the tensors. 
     # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
@@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler):
 
 
 def get_original(
-    noise_scheduler,
-    model_output,
-    sample: torch.FloatTensor,
-    timesteps: torch.IntTensor
+    noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor
 ):
     alphas_cumprod = noise_scheduler.alphas_cumprod
     sqrt_alphas_cumprod = alphas_cumprod**0.5
     sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
 
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(sample.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
@@ -329,7 +358,9 @@ def loss_step(
     eval: bool = False,
 ):
     images = batch["pixel_values"]
-    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    generator = (
+        torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    )
     bsz = images.shape[0]
 
     # Convert images to latent space
@@ -342,7 +373,7 @@ def loss_step(
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
-        generator=generator
+        generator=generator,
     )
     applied_noise = noise
 
@@ -353,7 +384,7 @@ def loss_step(
             octaves=4,
             dtype=latents.dtype,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     if input_pertubation != 0:
@@ -362,7 +393,7 @@ def loss_step(
             dtype=latents.dtype,
             layout=latents.layout,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     # Sample a random timestep for each image
@@ -375,25 +406,27 @@ def loss_step(
 
     # Get the text embedding for conditioning
     encoder_hidden_states = get_extended_embeddings(
-        text_encoder,
-        batch["input_ids"],
-        batch["attention_mask"]
+        text_encoder, batch["input_ids"], batch["attention_mask"]
     )
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
 
     # Predict the noise residual
-    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+    model_pred = unet(
+        noisy_latents, timesteps, encoder_hidden_states, return_dict=False
+    )[0]
 
     if guidance_scale != 0:
         uncond_encoder_hidden_states = get_extended_embeddings(
-            text_encoder,
-            batch["negative_input_ids"],
-            batch["negative_attention_mask"]
+            text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"]
         )
         uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
 
-        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0]
-        model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
+        model_pred_uncond = unet(
+            noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False
+        )[0]
+        model_pred = model_pred_uncond + guidance_scale * (
+            model_pred - model_pred_uncond
+        )
 
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
@@ -401,7 +434,9 @@ def loss_step(
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
     else:
-        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+        raise ValueError(
+            f"Unknown prediction type {noise_scheduler.config.prediction_type}"
+        )
 
     acc = (model_pred == target).float().mean()
 
@@ -414,7 +449,9 @@ def loss_step(
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
         # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+        prior_loss = F.mse_loss(
+            model_pred_prior.float(), target_prior.float(), reduction="none"
+        )
 
         # Add the prior loss to the instance loss.
         loss = loss + prior_loss_weight * prior_loss
@@ -433,7 +470,10 @@ def loss_step(
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
         mse_loss_weights = (
-            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                dim=1
+            )[0]
+            / snr
         )
         loss = loss * mse_loss_weights
 
@@ -447,8 +487,14 @@ def loss_step(
 
 
 class LossCallable(Protocol):
-    def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any],
-                 eval: bool = False) -> Tuple[Any, Any, int]: ...
+    def __call__(
+        self,
+        step: int,
+        batch: dict[Any, Any],
+        cache: dict[str, Any],
+        eval: bool = False,
+    ) -> Tuple[Any, Any, int]:
+        ...
 
 
 def train_loop(
@@ -472,9 +518,14 @@ def train_loop(
     avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = math.ceil(
-        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
+    num_training_steps_per_epoch = math.ceil(
+        len(train_dataloader) / gradient_accumulation_steps
+    )
+    num_val_steps_per_epoch = (
+        math.ceil(len(val_dataloader) / gradient_accumulation_steps)
+        if val_dataloader is not None
+        else 0
+    )
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -488,14 +539,14 @@ def train_loop(
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
         range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     global_progress_bar.set_description("Total progress")
 
@@ -513,7 +564,9 @@ def train_loop(
     try:
         import dadaptation
 
-        isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
+        isDadaptation = isinstance(
+            optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)
+        )
     except ImportError:
         pass
 
@@ -565,7 +618,10 @@ def train_loop(
                    label = group_labels[i] if i < len(group_labels) else f"{i}"
                    logs[f"lr/{label}"] = lr
                    if isDadaptation:
-                        lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                        lr = (
+                            optimizer.param_groups[i]["d"]
+                            * optimizer.param_groups[i]["lr"]
+                        )
                        logs[f"d*lr/{label}"] = lr
                    lrs[label] = lr
 
@@ -573,8 +629,10 @@ def train_loop(
 
                local_progress_bar.set_postfix(**logs)
 
-                if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                    before_optimize_result = on_before_optimize(epoch)
+                if ((step + 1) % gradient_accumulation_steps == 0) or (
+                    (step + 1) == len(train_dataloader)
+                ):
+                    before_optimize_result = on_before_optimize(cycle)
 
                    optimizer.step()
                    lr_scheduler.step()
@@ -614,7 +672,9 @@ def train_loop(
                    }
                    local_progress_bar.set_postfix(**logs)
 
-                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                    if ((step + 1) % gradient_accumulation_steps == 0) or (
+                        (step + 1) == len(val_dataloader)
+                    ):
                        local_progress_bar.update(1)
                        global_progress_bar.update(1)
 
@@ -634,7 +694,8 @@ def train_loop(
                    global_progress_bar.clear()
 
                    accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}"
+                    )
                    on_checkpoint(global_step, "milestone")
                    best_acc_val = avg_acc_val.max
        else:
@@ -644,7 +705,8 @@ def train_loop(
                    global_progress_bar.clear()
 
                    accelerator.print(
-                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
+                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}"
+                    )
                    on_checkpoint(global_step, "milestone")
                    best_acc = avg_acc.max
 
@@ -700,17 +762,32 @@ def train(
     avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = strategy.prepare(
+        accelerator,
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+        **kwargs,
+    )
 
     vae.to(accelerator.device, dtype=dtype)
     vae.requires_grad_(False)
     vae.eval()
 
-    vae = torch.compile(vae, backend='hidet')
+    vae = torch.compile(vae, backend="hidet")
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, backend="hidet")
         # unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
--
cgit v1.2.3-54-g00ecf