| author | Volpeon <git@volpeon.ink> | 2023-03-07 07:11:51 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-07 07:11:51 +0100 |
| commit | fe3113451fdde72ddccfc71639f0a2a1e146209a (patch) | |
| tree | ba4114faf1bd00a642f97b5e7729ad74213c3b80 /training | |
| parent | Update (diff) | |
Update
Diffstat (limited to 'training')

| -rw-r--r-- | training/functional.py | 6 |
| -rw-r--r-- | training/strategy/lora.py | 25 |

2 files changed, 16 insertions, 15 deletions
diff --git a/training/functional.py b/training/functional.py
index 27a43c2..4565612 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,12 +231,16 @@ def add_placeholder_tokens(
     embeddings: ManagedCLIPTextEmbeddings,
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
+    num_vectors: Optional[Union[list[int], int]] = None,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
         for token in initializer_tokens
     ]
+
+    if num_vectors is None:
+        num_vectors = [len(ids) for ids in initializer_token_ids]
+
     placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
 
     embeddings.resize(len(tokenizer))
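With this change `num_vectors` is optional: when omitted, each placeholder gets as many embedding vectors as its initializer string encodes to tokens. A minimal standalone sketch of that defaulting rule, using the stock `transformers` CLIP tokenizer rather than this repository's `MultiCLIPTokenizer`:

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

initializer_tokens = ["a photo of a cute dog", "shiny"]

# Same defaulting logic as the patched add_placeholder_tokens():
# one embedding vector per token id produced by each initializer string.
initializer_token_ids = [
    tokenizer.encode(token, add_special_tokens=False)
    for token in initializer_tokens
]
num_vectors = [len(ids) for ids in initializer_token_ids]

print(num_vectors)  # e.g. [6, 1] -- a multi-word initializer yields a multi-vector placeholder
```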
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ccec215..cab5e4c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -11,10 +11,7 @@ from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.loaders import AttnProcsLayers
 
-from slugify import slugify
-
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -41,16 +38,9 @@ def lora_strategy_callbacks(
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
 
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
         text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
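Note that `unet` is no longer bound into the `partial`; it is passed per call from `on_sample` in the next hunk, so sampling always uses the freshly unwrapped model. A minimal sketch of that pattern (the `render` helper is hypothetical, standing in for `save_samples`):

```python
from functools import partial

def render(*, model, step, seed=0):
    # Placeholder for save_samples(): uses whatever model is supplied at call time.
    return f"step={step} model={model} seed={seed}"

# Bind only the arguments that never change between calls...
render_ = partial(render, seed=42)

# ...and pass the (re-)unwrapped model each time it is needed.
print(render_(model="unet@step10", step=10))
print(render_(model="unet@step20", step=20))
```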
@@ -83,20 +73,27 @@ def lora_strategy_callbacks(
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
+        unet_.save_attn_procs(
+            checkpoint_output_dir / f"{step}_{postfix}",
+            safe_serialization=True
+        )
         del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        save_samples_(step=step)
+        unet_ = accelerator.unwrap_model(unet, False)
+        save_samples_(step=step, unet=unet_)
+        del unet_
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
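`save_attn_procs(..., safe_serialization=True)` writes the LoRA attention-processor weights in safetensors format instead of a pickled `.bin`. A rough sketch of saving and reloading such a checkpoint with diffusers; the model id and paths are placeholders, and the `LoRAAttnProcessor` import location varies between diffusers releases:

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Attach rank-4 LoRA processors to every attention layer (the usual diffusers recipe).
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
unet.set_attn_processor(lora_attn_procs)

# ... training would happen here ...

# Write only the LoRA weights, as .safetensors rather than a pickled .bin.
unet.save_attn_procs("checkpoints/10000_end", safe_serialization=True)

# Restore the processors into a UNet later.
unet.load_attn_procs("checkpoints/10000_end")
```

The other change in this hunk has `on_sample` unwrap the UNet itself, hand it to `save_samples_`, and free the CUDA cache afterwards, which keeps sampling memory from lingering between checkpoints.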
