From fe3113451fdde72ddccfc71639f0a2a1e146209a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 7 Mar 2023 07:11:51 +0100
Subject: Update

---
 training/functional.py    |  6 +++++-
 training/strategy/lora.py | 25 +++++++++++--------------
 2 files changed, 16 insertions(+), 15 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 27a43c2..4565612 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,12 +231,16 @@ def add_placeholder_tokens(
     embeddings: ManagedCLIPTextEmbeddings,
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
+    num_vectors: Optional[Union[list[int], int]] = None,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
         for token in initializer_tokens
     ]
+
+    if num_vectors is None:
+        num_vectors = [len(ids) for ids in initializer_token_ids]
+
     placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
 
     embeddings.resize(len(tokenizer))
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ccec215..cab5e4c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -11,10 +11,7 @@ from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.loaders import AttnProcsLayers
 
-from slugify import slugify
-
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
@@ -41,16 +38,9 @@ def lora_strategy_callbacks(
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
 
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
         text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
@@ -83,20 +73,27 @@ def lora_strategy_callbacks(
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
+        unet_.save_attn_procs(
+            checkpoint_output_dir / f"{step}_{postfix}",
+            safe_serialization=True
+        )
         del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        save_samples_(step=step)
+        unet_ = accelerator.unwrap_model(unet, False)
+        save_samples_(step=step, unet=unet_)
+        del unet_
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
-- 
cgit v1.2.3-54-g00ecf
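
Note (editor's sketch, not part of the commit): the change to add_placeholder_tokens() makes num_vectors optional and, when it is omitted, derives one embedding vector per sub-token of the corresponding initializer token. A minimal, standalone illustration of that defaulting logic follows; the helper name and the token ids are hypothetical stand-ins for what MultiCLIPTokenizer.encode() would return.

# Sketch only: mirrors the added "if num_vectors is None" branch in
# add_placeholder_tokens(); resolve_num_vectors() is an illustrative helper,
# not a function from the repository.
from typing import Optional, Union


def resolve_num_vectors(
    initializer_token_ids: list[list[int]],
    num_vectors: Optional[Union[list[int], int]] = None,
) -> Union[list[int], int]:
    # When num_vectors is not given, each placeholder gets as many vectors
    # as its initializer token encodes to; explicit values pass through.
    if num_vectors is None:
        num_vectors = [len(ids) for ids in initializer_token_ids]
    return num_vectors


if __name__ == "__main__":
    # Hypothetical encodings: a two-sub-token initializer and a one-token one.
    ids = [[320, 1125], [42]]
    print(resolve_num_vectors(ids))      # [2, 1]
    print(resolve_num_vectors(ids, 3))   # 3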