Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/ti.py  164
1 file changed, 164 insertions, 0 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
new file mode 100644
index 0000000..83dc566
--- /dev/null
+++ b/training/strategy/ti.py
@@ -0,0 +1,164 @@
from typing import Optional
from functools import partial
from contextlib import contextmanager, nullcontext
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import save_samples


def textual_inversion_strategy(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    dtype: torch.dtype,
    output_dir: Path,
    seed: int,
    placeholder_tokens: list[str],
    placeholder_token_ids: list[list[int]],
    learning_rate: float,
    gradient_checkpointing: bool = False,
    use_emb_decay: bool = False,
    emb_decay_target: float = 0.4,
    emb_decay_factor: float = 1,
    emb_decay_start: float = 1e-4,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: int = 1,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    # Pre-bind everything save_samples needs so the sampling callback only has to pass the step.
    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=dtype,
        output_dir=output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    # Optionally keep an EMA copy of the trainable placeholder-token embeddings.
    if use_ema:
        ema_embeddings = EMAModel(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
    else:
        ema_embeddings = None

    def on_prepare():
        # Enable gradients for the temporary (placeholder) token embeddings.
        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

        if use_ema:
            ema_embeddings.to(accelerator.device)

        if gradient_checkpointing:
            # Gradient checkpointing is only active while the unet is in training mode.
            unet.train()

    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()

        ema_context = ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if use_ema else nullcontext()

        with ema_context:
            yield

    @torch.no_grad()
    def on_after_optimize(lr: float):
        if use_emb_decay:
            # The decay strength scales with the current learning rate, clamped to [0, 1].
            text_encoder.text_model.embeddings.normalize(
                emb_decay_target,
                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
            )

        if use_ema:
            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())

    def on_log():
        if use_ema:
            return {"ema_decay": ema_embeddings.decay}
        return {}

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        print(f"Saving checkpoint for step {step}...")

        checkpoints_path = output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        # Bind the unwrapped model to a new name; reassigning text_encoder here would
        # shadow the closure variable and raise UnboundLocalError.
        text_encoder_ = accelerator.unwrap_model(text_encoder)

        ema_context = ema_embeddings.apply_temporary(
            text_encoder_.text_model.embeddings.temp_token_embedding.parameters()
        ) if ema_embeddings is not None else nullcontext()

        with ema_context:
            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                text_encoder_.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

    @torch.no_grad()
    def on_sample(step):
        ema_context = ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if ema_embeddings is not None else nullcontext()

        with ema_context:
            save_samples_(step=step)

    return {
        "on_prepare": on_prepare,
        "on_train": on_train,
        "on_eval": on_eval,
        "on_after_optimize": on_after_optimize,
        "on_log": on_log,
        "on_checkpoint": on_checkpoint,
        "on_sample": on_sample,
    }
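
For orientation, here is a minimal sketch of how the callback dictionary returned above could be driven by a training loop. This is not the repository's actual trainer (the callbacks are consumed by code in training.functional); the loop, along with names such as optimizer, lr_scheduler, loss_fn, num_epochs, and all example argument values, are assumptions made purely for illustration.

# Illustrative usage only: a hand-rolled loop driving the strategy callbacks.
# The real trainer lives in training.functional; `optimizer`, `lr_scheduler`,
# `loss_fn`, `num_epochs`, and the argument values are assumed for this sketch.
callbacks = textual_inversion_strategy(
    accelerator=accelerator,
    unet=unet,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae,
    sample_scheduler=sample_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    dtype=torch.float32,
    output_dir=Path("output/ti"),
    seed=42,
    placeholder_tokens=["<token>"],
    placeholder_token_ids=[[49408]],
    learning_rate=1e-4,
    use_ema=True,
)

callbacks["on_prepare"]()

global_step = 0
for epoch in range(num_epochs):
    with callbacks["on_train"](epoch):
        for batch in train_dataloader:
            loss = loss_fn(batch)  # assumed loss computation
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            callbacks["on_after_optimize"](lr_scheduler.get_last_lr()[0])
            global_step += 1

    with callbacks["on_eval"]():
        pass  # run validation here, with EMA embeddings swapped in if enabled

    accelerator.log(callbacks["on_log"](), step=global_step)
    callbacks["on_checkpoint"](global_step, "end")
    callbacks["on_sample"](global_step)

The context-manager callbacks (on_train, on_eval) toggle the MultiCLIPTokenizer between its train and eval behaviour, and on_eval, on_checkpoint, and on_sample temporarily swap in the EMA embeddings when use_ema is enabled, so evaluation, saved embeddings, and samples all reflect the smoothed weights.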