From fe3113451fdde72ddccfc71639f0a2a1e146209a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 7 Mar 2023 07:11:51 +0100
Subject: Update

---
 .../stable_diffusion/vlpn_stable_diffusion.py |  2 +-
 train_lora.py                                 | 23 ++++++++++++++------
 train_ti.py                                   |  8 +++----
 training/functional.py                        |  6 +++++-
 training/strategy/lora.py                     | 25 ++++++++++------------
 5 files changed, 37 insertions(+), 27 deletions(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f426de1..4505a2a 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -394,7 +394,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = "noise",
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
diff --git a/train_lora.py b/train_lora.py
index 6e72376..e65e7be 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -291,6 +291,12 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=256,
+        help="LoRA rank.",
+    )
     parser.add_argument(
         "--sample_frequency",
        type=int,
@@ -420,10 +426,6 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
@@ -439,11 +441,18 @@
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.lora_rank
         )
 
     unet.set_attn_processor(lora_attn_procs)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
     lora_layers = AttnProcsLayers(unet.attn_processors)
 
     if args.embeddings_dir is not None:
diff --git a/train_ti.py b/train_ti.py
index b9d6e56..81938c8 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -476,13 +476,10 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
-    if args.num_vectors is None:
-        args.num_vectors = 1
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
-    if len(args.placeholder_tokens) != len(args.num_vectors):
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
     if args.sequential:
@@ -491,6 +488,9 @@
 
         if len(args.placeholder_tokens) != len(args.train_data_template):
             raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+
+        if args.num_vectors is None:
+            args.num_vectors = [None] * len(args.placeholder_tokens)
     else:
         if isinstance(args.train_data_template, list):
             raise ValueError("--train_data_template can't be a list in simultaneous mode")
diff --git a/training/functional.py b/training/functional.py
index 27a43c2..4565612 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,12 +231,16 @@ def add_placeholder_tokens(
     embeddings: ManagedCLIPTextEmbeddings,
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
+    num_vectors: Optional[Union[list[int], int]] = None,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
         for token in initializer_tokens
     ]
+
+    if num_vectors is None:
+        num_vectors = [len(ids) for ids in initializer_token_ids]
+
     placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
 
     embeddings.resize(len(tokenizer))
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ccec215..cab5e4c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -11,10 +11,7 @@ from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.loaders import AttnProcsLayers
 
-from slugify import slugify
-
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -41,16 +38,9 @@ def lora_strategy_callbacks(
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
 
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
         text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
@@ -83,20 +73,27 @@
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
+        unet_.save_attn_procs(
+            checkpoint_output_dir / f"{step}_{postfix}",
+            safe_serialization=True
+        )
         del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        save_samples_(step=step)
+        unet_ = accelerator.unwrap_model(unet, False)
+        save_samples_(step=step, unet=unet_)
+        del unet_
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
-- 
cgit v1.2.3-70-g09d2