diff options
author | Volpeon <git@volpeon.ink> | 2023-04-01 22:13:55 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-01 22:13:55 +0200 |
commit | 208e48134e324e934ad964bdc61880cc923f4c0d (patch) | |
tree | c215f6c201c04b0b2d18ba0df230fb4c5e622985 /training/strategy | |
parent | Fix (diff) | |
download | textual-inversion-diff-208e48134e324e934ad964bdc61880cc923f4c0d.tar.gz textual-inversion-diff-208e48134e324e934ad964bdc61880cc923f4c0d.tar.bz2 textual-inversion-diff-208e48134e324e934ad964bdc61880cc923f4c0d.zip |
Revert
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/ti.py | 100 |
1 files changed, 81 insertions, 19 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 1b5adab..677f5a3 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -1,6 +1,6 @@ | |||
1 | from typing import Optional | 1 | from typing import Optional |
2 | from functools import partial | 2 | from functools import partial |
3 | from contextlib import contextmanager | 3 | from contextlib import contextmanager, nullcontext |
4 | from pathlib import Path | 4 | from pathlib import Path |
5 | 5 | ||
6 | import torch | 6 | import torch |
@@ -13,6 +13,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch | |||
13 | from slugify import slugify | 13 | from slugify import slugify |
14 | 14 | ||
15 | from models.clip.tokenizer import MultiCLIPTokenizer | 15 | from models.clip.tokenizer import MultiCLIPTokenizer |
16 | from training.util import EMAModel | ||
16 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | 17 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
17 | 18 | ||
18 | 19 | ||
@@ -31,6 +32,13 @@ def textual_inversion_strategy_callbacks( | |||
31 | placeholder_tokens: list[str], | 32 | placeholder_tokens: list[str], |
32 | placeholder_token_ids: list[list[int]], | 33 | placeholder_token_ids: list[list[int]], |
33 | gradient_checkpointing: bool = False, | 34 | gradient_checkpointing: bool = False, |
35 | use_emb_decay: bool = False, | ||
36 | emb_decay_target: float = 0.4, | ||
37 | emb_decay: float = 1e-2, | ||
38 | use_ema: bool = False, | ||
39 | ema_inv_gamma: float = 1.0, | ||
40 | ema_power: int = 1, | ||
41 | ema_max_decay: float = 0.9999, | ||
34 | sample_batch_size: int = 1, | 42 | sample_batch_size: int = 1, |
35 | sample_num_batches: int = 1, | 43 | sample_num_batches: int = 1, |
36 | sample_num_steps: int = 20, | 44 | sample_num_steps: int = 20, |
@@ -63,8 +71,27 @@ def textual_inversion_strategy_callbacks( | |||
63 | image_size=sample_image_size, | 71 | image_size=sample_image_size, |
64 | ) | 72 | ) |
65 | 73 | ||
74 | if use_ema: | ||
75 | ema_embeddings = EMAModel( | ||
76 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
77 | inv_gamma=ema_inv_gamma, | ||
78 | power=ema_power, | ||
79 | max_value=ema_max_decay, | ||
80 | ) | ||
81 | ema_embeddings.to(accelerator.device) | ||
82 | else: | ||
83 | ema_embeddings = None | ||
84 | |||
85 | def ema_context(): | ||
86 | if ema_embeddings is not None: | ||
87 | return ema_embeddings.apply_temporary( | ||
88 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | ||
89 | ) | ||
90 | else: | ||
91 | return nullcontext() | ||
92 | |||
66 | def on_accum_model(): | 93 | def on_accum_model(): |
67 | return text_encoder.text_model.embeddings | 94 | return text_encoder.text_model.embeddings.temp_token_embedding |
68 | 95 | ||
69 | @contextmanager | 96 | @contextmanager |
70 | def on_train(epoch: int): | 97 | def on_train(epoch: int): |
@@ -74,36 +101,68 @@ def textual_inversion_strategy_callbacks( | |||
74 | @contextmanager | 101 | @contextmanager |
75 | def on_eval(): | 102 | def on_eval(): |
76 | tokenizer.eval() | 103 | tokenizer.eval() |
77 | yield | 104 | |
105 | with ema_context(): | ||
106 | yield | ||
107 | |||
108 | @torch.no_grad() | ||
109 | def on_before_optimize(lr: float, epoch: int): | ||
110 | if use_emb_decay: | ||
111 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | ||
112 | return torch.all(w.grad == 0, dim=1) | ||
113 | |||
114 | @torch.no_grad() | ||
115 | def on_after_optimize(zero_ids, lr: float): | ||
116 | if ema_embeddings is not None: | ||
117 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
118 | |||
119 | if use_emb_decay: | ||
120 | lambda_ = emb_decay * lr | ||
121 | |||
122 | if lambda_ != 0: | ||
123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | ||
124 | |||
125 | mask = torch.ones(w.shape[0], dtype=torch.bool) | ||
126 | mask[zero_ids] = False | ||
127 | |||
128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | ||
129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
130 | |||
131 | def on_log(): | ||
132 | if ema_embeddings is not None: | ||
133 | return {"ema_decay": ema_embeddings.decay} | ||
134 | return {} | ||
78 | 135 | ||
79 | @torch.no_grad() | 136 | @torch.no_grad() |
80 | def on_checkpoint(step, postfix): | 137 | def on_checkpoint(step, postfix): |
81 | print(f"Saving checkpoint for step {step}...") | 138 | print(f"Saving checkpoint for step {step}...") |
82 | 139 | ||
83 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 140 | with ema_context(): |
84 | text_encoder.text_model.embeddings.save_embed( | 141 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): |
85 | ids, | 142 | text_encoder.text_model.embeddings.save_embed( |
86 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 143 | ids, |
87 | ) | 144 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" |
145 | ) | ||
88 | 146 | ||
89 | @torch.no_grad() | 147 | @torch.no_grad() |
90 | def on_sample(step): | 148 | def on_sample(step): |
91 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 149 | with ema_context(): |
92 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | ||
93 | 152 | ||
94 | orig_unet_dtype = unet_.dtype | 153 | orig_unet_dtype = unet_.dtype |
95 | orig_text_encoder_dtype = text_encoder_.dtype | 154 | orig_text_encoder_dtype = text_encoder_.dtype |
96 | 155 | ||
97 | unet_.to(dtype=weight_dtype) | 156 | unet_.to(dtype=weight_dtype) |
98 | text_encoder_.to(dtype=weight_dtype) | 157 | text_encoder_.to(dtype=weight_dtype) |
99 | 158 | ||
100 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 159 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) |
101 | 160 | ||
102 | unet_.to(dtype=orig_unet_dtype) | 161 | unet_.to(dtype=orig_unet_dtype) |
103 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 162 | text_encoder_.to(dtype=orig_text_encoder_dtype) |
104 | 163 | ||
105 | del unet_ | 164 | del unet_ |
106 | del text_encoder_ | 165 | del text_encoder_ |
107 | 166 | ||
108 | if torch.cuda.is_available(): | 167 | if torch.cuda.is_available(): |
109 | torch.cuda.empty_cache() | 168 | torch.cuda.empty_cache() |
@@ -112,6 +171,9 @@ def textual_inversion_strategy_callbacks( | |||
112 | on_accum_model=on_accum_model, | 171 | on_accum_model=on_accum_model, |
113 | on_train=on_train, | 172 | on_train=on_train, |
114 | on_eval=on_eval, | 173 | on_eval=on_eval, |
174 | on_before_optimize=on_before_optimize, | ||
175 | on_after_optimize=on_after_optimize, | ||
176 | on_log=on_log, | ||
115 | on_checkpoint=on_checkpoint, | 177 | on_checkpoint=on_checkpoint, |
116 | on_sample=on_sample, | 178 | on_sample=on_sample, |
117 | ) | 179 | ) |