From 632ce00b54ffeacfc18f44f10827f167ab3ac37c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 21:06:16 +0100
Subject: Restored functional trainer

---
 data/csv.py                                   |   8 --
 .../stable_diffusion/vlpn_stable_diffusion.py |  16 ++--
 train_ti.py                                   |  82 +++++------
 training/functional.py                        | 102 ++++++++++++++++-----
 training/util.py                              |   8 +-
 5 files changed, 112 insertions(+), 104 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 5de3ac7..2a8115b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -15,9 +15,6 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.util import unify_input_ids
 
 
-image_cache: dict[str, Image.Image] = {}
-
-
 interpolations = {
     "linear": transforms.InterpolationMode.NEAREST,
     "bilinear": transforms.InterpolationMode.BILINEAR,
@@ -27,14 +24,9 @@ interpolations = {
 
 
 def get_image(path):
-    if path in image_cache:
-        return image_cache[path]
-
     image = Image.open(path)
     if not image.mode == "RGB":
         image = image.convert("RGB")
-    image_cache[path] = image
-
     return image
 
 
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 43141bd..3027421 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -162,8 +162,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
-        width: Optional[int],
-        height: Optional[int],
+        width: int,
+        height: int,
         strength: float,
         callback_steps: Optional[int]
     ):
@@ -324,19 +324,19 @@
         self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
+        num_images_per_prompt: int = 1,
         strength: float = 0.8,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        eta: Optional[float] = 0.0,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        output_type: Optional[str] = "pil",
+        output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: Optional[int] = 1,
+        callback_steps: int = 1,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
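A note on the signature changes above: parameters whose defaults are never None lose the Optional wrapper, while height/width in __call__ keep it because None there means "derive the value elsewhere". A minimal, standalone illustration of that typing rule (invented function, not code from this repository):

    from typing import Optional

    def sample(num_inference_steps: int = 50, width: Optional[int] = None) -> str:
        # A non-None default: the parameter can never be None, so plain `int`.
        # `width` genuinely accepts None ("derive the value later"), so it
        # keeps Optional -- as height/width do in __call__ above.
        return f"{num_inference_steps} steps at {width or 512}px"

    print(sample())           # 50 steps at 512px
    print(sample(width=768))  # 50 steps at 768px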
diff --git a/train_ti.py b/train_ti.py
index 4bac736..77dec12 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -10,15 +10,13 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-import matplotlib.pyplot as plt
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.lr import LRFinder
 from training.util import save_args
 
 logger = get_logger(__name__)
@@ -644,23 +642,33 @@
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
-    callbacks = textual_inversion_strategy(
+    trainer = partial(
+        train,
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
-        tokenizer=tokenizer,
         vae=vae,
-        sample_scheduler=sample_scheduler,
+        noise_scheduler=noise_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        output_dir=output_dir,
+        dtype=weight_dtype,
         seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    trainer(
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        output_dir=output_dir,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         learning_rate=args.learning_rate,
@@ -679,54 +687,6 @@
         sample_image_size=args.sample_image_size,
     )
 
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.num_class_images != 0,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            callbacks=callbacks,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e3)
-
-        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        if accelerator.is_main_process:
-            accelerator.init_trackers("textual_inversion")
-
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            callbacks=callbacks,
-        )
-
-        accelerator.end_training()
-
 
 if __name__ == "__main__":
     main()
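A note on the trainer(...) call above: the `# --` comment marks where the keyword arguments stop being parameters of train() itself and start falling through its **kwargs into the strategy factory (here textual_inversion_strategy). A standalone sketch of that pass-through pattern with simplified, invented signatures, not the actual ones from this patch:

    from functools import partial
    from typing import Any, Callable

    def train(optimizer: str, num_train_epochs: int,
              callbacks_fn: Callable[..., dict], **kwargs: Any) -> None:
        # Anything train() does not name explicitly (tokenizer, output_dir, ...)
        # is forwarded verbatim to the strategy factory.
        callbacks = callbacks_fn(**kwargs)
        print(f"{num_train_epochs} epochs with {optimizer}; strategy got {callbacks}")

    def textual_inversion_strategy(**kwargs: Any) -> dict:
        return kwargs

    trainer = partial(train, callbacks_fn=textual_inversion_strategy)
    trainer(optimizer="adamw", num_train_epochs=100,
            tokenizer="clip", output_dir="/tmp/ti")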
diff --git a/training/functional.py b/training/functional.py
index c01595a..5984ffb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional
+from typing import Callable, Any, Tuple, Union, Optional, Type
 from functools import partial
 from pathlib import Path
 import itertools
@@ -32,7 +32,7 @@ def const(result=None):
 
 @dataclass
 class TrainingCallbacks():
-    on_prepare: Callable[[float], None] = const()
+    on_prepare: Callable[[], None] = const()
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -220,28 +220,6 @@ def generate_class_images(
     torch.cuda.empty_cache()
 
 
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
 def add_placeholder_tokens(
     tokenizer: MultiCLIPTokenizer,
     embeddings: ManagedCLIPTextEmbeddings,
@@ -508,3 +486,79 @@ def train_loop(
         if accelerator.is_main_process:
             print("Interrupted")
             on_checkpoint(global_step + global_step_offset, "end")
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    callbacks_fn: Callable[..., TrainingCallbacks],
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    with_prior_preservation: bool = False,
+    prior_loss_weight: float = 1.0,
+    **kwargs,
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks = callbacks_fn(
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        seed=seed,
+        **kwargs,
+    )
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        with_prior_preservation,
+        prior_loss_weight,
+        seed,
+    )
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        callbacks=callbacks,
+    )
+
+    accelerator.end_training()
+    accelerator.free_memory()
diff --git a/training/util.py b/training/util.py
index f46cc61..557b196 100644
--- a/training/util.py
+++ b/training/util.py
@@ -180,11 +180,13 @@ class EMAModel:
 
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
+        parameters = list(parameters)
+        original_params = [p.clone() for p in parameters]
+        self.copy_to(parameters)
+
         try:
-            parameters = list(parameters)
-            original_params = [p.clone() for p in parameters]
-            self.copy_to(parameters)
             yield
         finally:
             for o_param, param in zip(original_params, parameters):
                 param.data.copy_(o_param.data)
+            del original_params
--
cgit v1.2.3-54-g00ecf
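A note on the final training/util.py hunk: the snapshot and the copy of the EMA weights now happen before the try block, so the finally clause only restores parameters once the swap has actually taken place (and original_params is guaranteed to exist when it runs). A standalone sketch of the same context-manager pattern with a toy parameter, not the real EMAModel API:

    from contextlib import contextmanager
    import torch

    @contextmanager
    def apply_temporary(params, ema_values):
        # Snapshot and swap in the EMA weights *before* entering try, so the
        # finally-restore can only run once the swap has actually happened.
        originals = [p.clone() for p in params]
        for p, ema in zip(params, ema_values):
            p.data.copy_(ema)
        try:
            yield
        finally:
            for orig, p in zip(originals, params):
                p.data.copy_(orig.data)

    p = torch.nn.Parameter(torch.zeros(2))
    with apply_temporary([p], [torch.ones(2)]):
        print(p.data)  # tensor([1., 1.]) -- EMA weights active
    print(p.data)      # tensor([0., 0.]) -- originals restored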