From 632ce00b54ffeacfc18f44f10827f167ab3ac37c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 21:06:16 +0100
Subject: Restored functional trainer

---
 training/functional.py | 102 +++++++++++++++++++++++++++++++++++++------------
 training/util.py       |   8 ++--
 2 files changed, 83 insertions(+), 27 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index c01595a..5984ffb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional
+from typing import Callable, Any, Tuple, Union, Optional, Type
 from functools import partial
 from pathlib import Path
 import itertools
@@ -32,7 +32,7 @@ def const(result=None):
 
 @dataclass
 class TrainingCallbacks():
-    on_prepare: Callable[[float], None] = const()
+    on_prepare: Callable[[], None] = const()
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -220,28 +220,6 @@ def generate_class_images(
     torch.cuda.empty_cache()
 
 
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
 def add_placeholder_tokens(
     tokenizer: MultiCLIPTokenizer,
     embeddings: ManagedCLIPTextEmbeddings,
@@ -508,3 +486,79 @@ def train_loop(
         if accelerator.is_main_process:
             print("Interrupted")
             on_checkpoint(global_step + global_step_offset, "end")
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    callbacks_fn: Callable[..., TrainingCallbacks],
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    with_prior_preservation: bool = False,
+    prior_loss_weight: float = 1.0,
+    **kwargs,
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks = callbacks_fn(
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        seed=seed,
+        **kwargs,
+    )
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        with_prior_preservation,
+        prior_loss_weight,
+        seed,
+    )
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        callbacks=callbacks,
+    )
+
+    accelerator.end_training()
+    accelerator.free_memory()
diff --git a/training/util.py b/training/util.py
index f46cc61..557b196 100644
--- a/training/util.py
+++ b/training/util.py
@@ -180,11 +180,13 @@ class EMAModel:
 
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
+        parameters = list(parameters)
+        original_params = [p.clone() for p in parameters]
+        self.copy_to(parameters)
+
         try:
-            parameters = list(parameters)
-            original_params = [p.clone() for p in parameters]
-            self.copy_to(parameters)
             yield
         finally:
             for o_param, param in zip(original_params, parameters):
                 param.data.copy_(o_param.data)
+            del original_params
-- 
cgit v1.2.3-54-g00ecf
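
A note on the restored train() entry point: it expects callbacks_fn to be a
factory that it invokes with the keyword arguments visible in the hunk above
(accelerator, unet, text_encoder, vae, train_dataloader, val_dataloader, seed,
plus any extra **kwargs) and that returns a TrainingCallbacks instance; the
on_prepare hook then fires right after accelerator.prepare(). Below is a
minimal sketch of such a factory. The calling convention and the
on_prepare/on_log signatures come from the patch; make_callbacks is a stand-in
name and the hook bodies are purely illustrative.

from typing import Any

from training.functional import TrainingCallbacks


def make_callbacks(accelerator, unet, text_encoder, vae,
                   train_dataloader, val_dataloader, seed, **kwargs):
    # train() calls this factory with exactly these keyword arguments,
    # so the parameter names must match.
    def on_prepare():
        # Runs once, right after accelerator.prepare() has wrapped the
        # models, optimizer, dataloaders and lr scheduler.
        print(f"prepared on {accelerator.device} (seed {seed})")

    def on_log() -> dict[str, Any]:
        # Extra key/value pairs contributed to the training logs
        # (assumed usage; the patch only fixes the field signatures).
        return {"seed": seed}

    # Every TrainingCallbacks field defaults to a no-op const(), so only
    # the hooks actually needed have to be supplied.
    return TrainingCallbacks(on_prepare=on_prepare, on_log=on_log)


# Wired up as train(..., callbacks_fn=make_callbacks, ...); any extra
# keyword arguments passed to train() are forwarded to the factory.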
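
The training/util.py hunk fixes an ordering bug in EMAModel.apply_temporary:
with the parameter snapshot and copy_to() inside the try block, an exception
while building the snapshot would still trigger the finally clause, which then
references original_params before it was ever assigned, raising an
UnboundLocalError that masks the real error. Hoisting the setup above try
guarantees the restore code only runs once a complete backup exists, and the
added del original_params releases the cloned tensors as soon as they have
been copied back. A self-contained sketch of the same pattern (swap_in is a
stand-in name, not part of this repo):

from contextlib import contextmanager
from typing import Iterable

import torch


@contextmanager
def swap_in(new_values: Iterable[torch.Tensor],
            parameters: Iterable[torch.nn.Parameter]):
    # Mirrors the fixed apply_temporary: do all setup *before* entering
    # try, so the finally block only runs once the backup exists.
    parameters = list(parameters)
    original_params = [p.clone() for p in parameters]
    for param, value in zip(parameters, new_values):
        param.data.copy_(value)
    try:
        yield
    finally:
        for o_param, param in zip(original_params, parameters):
            param.data.copy_(o_param.data)


# The parameter reads all zeros inside the block and is restored after.
p = torch.nn.Parameter(torch.ones(3))
with swap_in([torch.zeros(3)], [p]):
    assert torch.equal(p.detach(), torch.zeros(3))
assert torch.equal(p.detach(), torch.ones(3))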