Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  |  4 ++--
-rw-r--r--  training/strategy/ti.py | 22 +++++++++++++++++++++-
2 files changed, 23 insertions(+), 3 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 1d8e2ee..96ecbc1 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_alpha)
+    embeddings = patch_managed_embeddings(text_encoder)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
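With emb_alpha removed from get_models, model construction no longer knows about embedding decay at all; that regularization moves into the textual-inversion strategy below. A minimal sketch of the updated call site, assuming get_models is importable from training.functional and using a placeholder checkpoint path:

# Hypothetical caller of the updated get_models(); the checkpoint path is a
# placeholder. Note nothing decay-related is configured at construction time.
from training.functional import get_models

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = \
    get_models("path/to/checkpoint")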
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 95128da..9df160a 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,6 +31,9 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -102,10 +105,26 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
+    def on_before_optimize(lr: float, epoch: int):
+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -149,6 +168,7 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
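Taken together, the two callbacks implement weight decay toward a target norm for the trained token embeddings: on_before_optimize stacks the embedding vectors that actually received gradients this step, the training loop presumably passes that tensor back in as w after the optimizer step, and on_after_optimize then nudges each vector's L2 norm toward emb_decay_target with step size emb_decay * lr. A standalone sketch of the update rule on a dummy tensor (shapes and values are illustrative, not from this repo):

import torch

# Illustration of the decay rule from on_after_optimize above: each row of w
# is rescaled toward a target L2 norm, with step size lambda_ = emb_decay * lr.
emb_decay, emb_decay_target, lr = 1e-2, 0.4, 1e-3
w = torch.randn(3, 768)  # stand-in for the stacked token embeddings

lambda_ = emb_decay * lr
if lambda_ != 0:
    norm = w.norm(dim=-1, keepdim=True)
    # rows longer than the target shrink, shorter rows grow
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

Note that on_before_optimize has to run while gradients are still populated (p.grad is not None filters for exactly the embeddings being trained), which is presumably why the tensor is captured there rather than inside on_after_optimize itself.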
