diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 2 | ||||
| -rw-r--r-- | training/strategy/ti.py | 100 | 
2 files changed, 82 insertions, 20 deletions
| diff --git a/training/functional.py b/training/functional.py index 7104a88..bd8cbad 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -524,7 +524,7 @@ def train_loop( | |||
| 524 | 524 | ||
| 525 | lr = lr_scheduler.get_last_lr()[0] | 525 | lr = lr_scheduler.get_last_lr()[0] | 
| 526 | if torch.is_tensor(lr): | 526 | if torch.is_tensor(lr): | 
| 527 | lr = lr.item | 527 | lr = lr.item() | 
| 528 | 528 | ||
| 529 | lrs.append(lr) | 529 | lrs.append(lr) | 
| 530 | 530 | ||
| diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 1b5adab..677f5a3 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -1,6 +1,6 @@ | |||
| 1 | from typing import Optional | 1 | from typing import Optional | 
| 2 | from functools import partial | 2 | from functools import partial | 
| 3 | from contextlib import contextmanager | 3 | from contextlib import contextmanager, nullcontext | 
| 4 | from pathlib import Path | 4 | from pathlib import Path | 
| 5 | 5 | ||
| 6 | import torch | 6 | import torch | 
| @@ -13,6 +13,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch | |||
| 13 | from slugify import slugify | 13 | from slugify import slugify | 
| 14 | 14 | ||
| 15 | from models.clip.tokenizer import MultiCLIPTokenizer | 15 | from models.clip.tokenizer import MultiCLIPTokenizer | 
| 16 | from training.util import EMAModel | ||
| 16 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | 17 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | 
| 17 | 18 | ||
| 18 | 19 | ||
| @@ -31,6 +32,13 @@ def textual_inversion_strategy_callbacks( | |||
| 31 | placeholder_tokens: list[str], | 32 | placeholder_tokens: list[str], | 
| 32 | placeholder_token_ids: list[list[int]], | 33 | placeholder_token_ids: list[list[int]], | 
| 33 | gradient_checkpointing: bool = False, | 34 | gradient_checkpointing: bool = False, | 
| 35 | use_emb_decay: bool = False, | ||
| 36 | emb_decay_target: float = 0.4, | ||
| 37 | emb_decay: float = 1e-2, | ||
| 38 | use_ema: bool = False, | ||
| 39 | ema_inv_gamma: float = 1.0, | ||
| 40 | ema_power: int = 1, | ||
| 41 | ema_max_decay: float = 0.9999, | ||
| 34 | sample_batch_size: int = 1, | 42 | sample_batch_size: int = 1, | 
| 35 | sample_num_batches: int = 1, | 43 | sample_num_batches: int = 1, | 
| 36 | sample_num_steps: int = 20, | 44 | sample_num_steps: int = 20, | 
| @@ -63,8 +71,27 @@ def textual_inversion_strategy_callbacks( | |||
| 63 | image_size=sample_image_size, | 71 | image_size=sample_image_size, | 
| 64 | ) | 72 | ) | 
| 65 | 73 | ||
| 74 | if use_ema: | ||
| 75 | ema_embeddings = EMAModel( | ||
| 76 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
| 77 | inv_gamma=ema_inv_gamma, | ||
| 78 | power=ema_power, | ||
| 79 | max_value=ema_max_decay, | ||
| 80 | ) | ||
| 81 | ema_embeddings.to(accelerator.device) | ||
| 82 | else: | ||
| 83 | ema_embeddings = None | ||
| 84 | |||
| 85 | def ema_context(): | ||
| 86 | if ema_embeddings is not None: | ||
| 87 | return ema_embeddings.apply_temporary( | ||
| 88 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | ||
| 89 | ) | ||
| 90 | else: | ||
| 91 | return nullcontext() | ||
| 92 | |||
| 66 | def on_accum_model(): | 93 | def on_accum_model(): | 
| 67 | return text_encoder.text_model.embeddings | 94 | return text_encoder.text_model.embeddings.temp_token_embedding | 
| 68 | 95 | ||
| 69 | @contextmanager | 96 | @contextmanager | 
| 70 | def on_train(epoch: int): | 97 | def on_train(epoch: int): | 
| @@ -74,36 +101,68 @@ def textual_inversion_strategy_callbacks( | |||
| 74 | @contextmanager | 101 | @contextmanager | 
| 75 | def on_eval(): | 102 | def on_eval(): | 
| 76 | tokenizer.eval() | 103 | tokenizer.eval() | 
| 77 | yield | 104 | |
| 105 | with ema_context(): | ||
| 106 | yield | ||
| 107 | |||
| 108 | @torch.no_grad() | ||
| 109 | def on_before_optimize(lr: float, epoch: int): | ||
| 110 | if use_emb_decay: | ||
| 111 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | ||
| 112 | return torch.all(w.grad == 0, dim=1) | ||
| 113 | |||
| 114 | @torch.no_grad() | ||
| 115 | def on_after_optimize(zero_ids, lr: float): | ||
| 116 | if ema_embeddings is not None: | ||
| 117 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
| 118 | |||
| 119 | if use_emb_decay: | ||
| 120 | lambda_ = emb_decay * lr | ||
| 121 | |||
| 122 | if lambda_ != 0: | ||
| 123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | ||
| 124 | |||
| 125 | mask = torch.ones(w.shape[0], dtype=torch.bool) | ||
| 126 | mask[zero_ids] = False | ||
| 127 | |||
| 128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | ||
| 129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
| 130 | |||
| 131 | def on_log(): | ||
| 132 | if ema_embeddings is not None: | ||
| 133 | return {"ema_decay": ema_embeddings.decay} | ||
| 134 | return {} | ||
| 78 | 135 | ||
| 79 | @torch.no_grad() | 136 | @torch.no_grad() | 
| 80 | def on_checkpoint(step, postfix): | 137 | def on_checkpoint(step, postfix): | 
| 81 | print(f"Saving checkpoint for step {step}...") | 138 | print(f"Saving checkpoint for step {step}...") | 
| 82 | 139 | ||
| 83 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 140 | with ema_context(): | 
| 84 | text_encoder.text_model.embeddings.save_embed( | 141 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 
| 85 | ids, | 142 | text_encoder.text_model.embeddings.save_embed( | 
| 86 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 143 | ids, | 
| 87 | ) | 144 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 
| 145 | ) | ||
| 88 | 146 | ||
| 89 | @torch.no_grad() | 147 | @torch.no_grad() | 
| 90 | def on_sample(step): | 148 | def on_sample(step): | 
| 91 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 149 | with ema_context(): | 
| 92 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 
| 151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | ||
| 93 | 152 | ||
| 94 | orig_unet_dtype = unet_.dtype | 153 | orig_unet_dtype = unet_.dtype | 
| 95 | orig_text_encoder_dtype = text_encoder_.dtype | 154 | orig_text_encoder_dtype = text_encoder_.dtype | 
| 96 | 155 | ||
| 97 | unet_.to(dtype=weight_dtype) | 156 | unet_.to(dtype=weight_dtype) | 
| 98 | text_encoder_.to(dtype=weight_dtype) | 157 | text_encoder_.to(dtype=weight_dtype) | 
| 99 | 158 | ||
| 100 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 159 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 
| 101 | 160 | ||
| 102 | unet_.to(dtype=orig_unet_dtype) | 161 | unet_.to(dtype=orig_unet_dtype) | 
| 103 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 162 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 
| 104 | 163 | ||
| 105 | del unet_ | 164 | del unet_ | 
| 106 | del text_encoder_ | 165 | del text_encoder_ | 
| 107 | 166 | ||
| 108 | if torch.cuda.is_available(): | 167 | if torch.cuda.is_available(): | 
| 109 | torch.cuda.empty_cache() | 168 | torch.cuda.empty_cache() | 
| @@ -112,6 +171,9 @@ def textual_inversion_strategy_callbacks( | |||
| 112 | on_accum_model=on_accum_model, | 171 | on_accum_model=on_accum_model, | 
| 113 | on_train=on_train, | 172 | on_train=on_train, | 
| 114 | on_eval=on_eval, | 173 | on_eval=on_eval, | 
| 174 | on_before_optimize=on_before_optimize, | ||
| 175 | on_after_optimize=on_after_optimize, | ||
| 176 | on_log=on_log, | ||
| 115 | on_checkpoint=on_checkpoint, | 177 | on_checkpoint=on_checkpoint, | 
| 116 | on_sample=on_sample, | 178 | on_sample=on_sample, | 
| 117 | ) | 179 | ) | 
