Diffstat (limited to 'training/strategy')
 -rw-r--r--  training/strategy/ti.py | 100
 1 file changed, 81 insertions(+), 19 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 1b5adab..677f5a3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,6 +1,6 @@
 from typing import Optional
 from functools import partial
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from pathlib import Path
 
 import torch
@@ -13,6 +13,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -31,6 +32,13 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: int = 1,
+    ema_max_decay: float = 0.9999,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
     sample_num_steps: int = 20,
@@ -63,8 +71,27 @@ def textual_inversion_strategy_callbacks(
         image_size=sample_image_size,
     )
 
+    if use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+        ema_embeddings.to(accelerator.device)
+    else:
+        ema_embeddings = None
+
+    def ema_context():
+        if ema_embeddings is not None:
+            return ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            )
+        else:
+            return nullcontext()
+
     def on_accum_model():
-        return text_encoder.text_model.embeddings
+        return text_encoder.text_model.embeddings.temp_token_embedding
 
     @contextmanager
     def on_train(epoch: int):
@@ -74,36 +101,68 @@ def textual_inversion_strategy_callbacks(
     @contextmanager
     def on_eval():
         tokenizer.eval()
-        yield
+
+        with ema_context():
+            yield
+
+    @torch.no_grad()
+    def on_before_optimize(lr: float, epoch: int):
+        if use_emb_decay:
+            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+            return torch.all(w.grad == 0, dim=1)
+
+    @torch.no_grad()
+    def on_after_optimize(zero_ids, lr: float):
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.ones(w.shape[0], dtype=torch.bool)
+                mask[zero_ids] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
+    def on_log():
+        if ema_embeddings is not None:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
-            text_encoder.text_model.embeddings.save_embed(
-                ids,
-                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
-            )
+        with ema_context():
+            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                )
 
     @torch.no_grad()
     def on_sample(step):
-        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+        with ema_context():
+            unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
-        orig_unet_dtype = unet_.dtype
-        orig_text_encoder_dtype = text_encoder_.dtype
+            orig_unet_dtype = unet_.dtype
+            orig_text_encoder_dtype = text_encoder_.dtype
 
-        unet_.to(dtype=weight_dtype)
-        text_encoder_.to(dtype=weight_dtype)
+            unet_.to(dtype=weight_dtype)
+            text_encoder_.to(dtype=weight_dtype)
 
-        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
-        unet_.to(dtype=orig_unet_dtype)
-        text_encoder_.to(dtype=orig_text_encoder_dtype)
+            unet_.to(dtype=orig_unet_dtype)
+            text_encoder_.to(dtype=orig_text_encoder_dtype)
 
-        del unet_
-        del text_encoder_
+            del unet_
+            del text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -112,6 +171,9 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
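
Note: EMAModel is imported from training.util rather than from diffusers, so its exact decay schedule is repo-specific. If it follows the usual diffusers-style schedule, the "ema_decay" value reported by on_log() ramps up with the optimizer step roughly as in this sketch (ema_decay_schedule is a hypothetical helper, not part of this commit):

    def ema_decay_schedule(step: int, inv_gamma: float = 1.0,
                           power: float = 1.0, max_value: float = 0.9999) -> float:
        # Decay starts near 0 and saturates at max_value as training progresses.
        value = 1.0 - (1.0 + step / inv_gamma) ** -power
        return min(max(value, 0.0), max_value)

    # With the defaults above (ema_inv_gamma=1.0, ema_power=1, ema_max_decay=0.9999):
    # step 1 -> 0.5, step 99 -> 0.99, step 9999 -> 0.9999 (clamped thereafter).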
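The use_emb_decay path in on_after_optimize pulls the L2 norm of each trained embedding toward emb_decay_target, scaled by emb_decay * lr, skipping rows whose gradient was all zero (the mask returned by on_before_optimize). A standalone sketch of that update (decay_embedding_norms is a hypothetical name); note that indexing a tensor with a boolean mask returns a copy, so the sketch writes back with indexed assignment rather than calling add_() on w[mask]:

    import torch

    @torch.no_grad()
    def decay_embedding_norms(w: torch.Tensor, skip_mask: torch.Tensor, lr: float,
                              emb_decay: float = 1e-2,
                              emb_decay_target: float = 0.4) -> None:
        # w: [num_embeddings, dim]; skip_mask: True for rows whose grad was all zero.
        lambda_ = emb_decay * lr
        if lambda_ == 0:
            return

        mask = ~skip_mask  # only decay embeddings that were actually updated
        norm = w[mask].norm(dim=-1, keepdim=True)
        # Move each row's norm toward emb_decay_target along its own direction.
        w[mask] = w[mask] + (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)

This keeps the learned embeddings from drifting to norms far larger than the rest of the vocabulary as training progresses.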