summary | refs | log | tree | commit | diff | stats
path: root/training/strategy
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy')
-rw-r--r-- training/strategy/ti.py 76
1 file changed, 18 insertions, 58 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 19b8d25..33f5fb9 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,6 +1,6 @@
1from typing import Optional 1from typing import Optional
2from functools import partial 2from functools import partial
3from contextlib import contextmanager, nullcontext 3from contextlib import contextmanager
4from pathlib import Path 4from pathlib import Path
5 5
6import torch 6import torch
@@ -13,7 +13,6 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
13from slugify import slugify 13from slugify import slugify
14 14
15from models.clip.tokenizer import MultiCLIPTokenizer 15from models.clip.tokenizer import MultiCLIPTokenizer
16from training.util import EMAModel
17from training.functional import TrainingStrategy, TrainingCallbacks, save_samples 16from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
18 17
19 18
@@ -32,10 +31,6 @@ def textual_inversion_strategy_callbacks(
32 placeholder_tokens: list[str], 31 placeholder_tokens: list[str],
33 placeholder_token_ids: list[list[int]], 32 placeholder_token_ids: list[list[int]],
34 gradient_checkpointing: bool = False, 33 gradient_checkpointing: bool = False,
35 use_ema: bool = False,
36 ema_inv_gamma: float = 1.0,
37 ema_power: int = 1,
38 ema_max_decay: float = 0.9999,
39 sample_batch_size: int = 1, 34 sample_batch_size: int = 1,
40 sample_num_batches: int = 1, 35 sample_num_batches: int = 1,
41 sample_num_steps: int = 20, 36 sample_num_steps: int = 20,
@@ -68,25 +63,6 @@ def textual_inversion_strategy_callbacks(
68 image_size=sample_image_size, 63 image_size=sample_image_size,
69 ) 64 )
70 65
71 if use_ema:
72 ema_embeddings = EMAModel(
73 text_encoder.text_model.embeddings.overlay.parameters(),
74 inv_gamma=ema_inv_gamma,
75 power=ema_power,
76 max_value=ema_max_decay,
77 )
78 ema_embeddings.to(accelerator.device)
79 else:
80 ema_embeddings = None
81
82 def ema_context():
83 if ema_embeddings is not None:
84 return ema_embeddings.apply_temporary(
85 text_encoder.text_model.embeddings.overlay.parameters()
86 )
87 else:
88 return nullcontext()
89
90 def on_accum_model(): 66 def on_accum_model():
91 return text_encoder.text_model.embeddings.overlay 67 return text_encoder.text_model.embeddings.overlay
92 68
@@ -98,50 +74,36 @@ def textual_inversion_strategy_callbacks(
98 @contextmanager 74 @contextmanager
99 def on_eval(): 75 def on_eval():
100 tokenizer.eval() 76 tokenizer.eval()
101 77 yield
102 with ema_context():
103 yield
104
105 @torch.no_grad()
106 def on_after_optimize(zero_ids, lr: float):
107 if ema_embeddings is not None:
108 ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
109
110 def on_log():
111 if ema_embeddings is not None:
112 return {"ema_decay": ema_embeddings.decay}
113 return {}
114 78
115 @torch.no_grad() 79 @torch.no_grad()
116 def on_checkpoint(step, postfix): 80 def on_checkpoint(step, postfix):
117 print(f"Saving checkpoint for step {step}...") 81 print(f"Saving checkpoint for step {step}...")
118 82
119 with ema_context(): 83 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
120 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): 84 text_encoder.text_model.embeddings.save_embed(
121 text_encoder.text_model.embeddings.save_embed( 85 ids,
122 ids, 86 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
123 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" 87 )
124 )
125 88
126 @torch.no_grad() 89 @torch.no_grad()
127 def on_sample(step): 90 def on_sample(step):
128 with ema_context(): 91 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
129 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 92 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
130 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
131 93
132 orig_unet_dtype = unet_.dtype 94 orig_unet_dtype = unet_.dtype
133 orig_text_encoder_dtype = text_encoder_.dtype 95 orig_text_encoder_dtype = text_encoder_.dtype
134 96
135 unet_.to(dtype=weight_dtype) 97 unet_.to(dtype=weight_dtype)
136 text_encoder_.to(dtype=weight_dtype) 98 text_encoder_.to(dtype=weight_dtype)
137 99
138 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) 100 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
139 101
140 unet_.to(dtype=orig_unet_dtype) 102 unet_.to(dtype=orig_unet_dtype)
141 text_encoder_.to(dtype=orig_text_encoder_dtype) 103 text_encoder_.to(dtype=orig_text_encoder_dtype)
142 104
143 del unet_ 105 del unet_
144 del text_encoder_ 106 del text_encoder_
145 107
146 if torch.cuda.is_available(): 108 if torch.cuda.is_available():
147 torch.cuda.empty_cache() 109 torch.cuda.empty_cache()
@@ -150,8 +112,6 @@ def textual_inversion_strategy_callbacks(
150 on_accum_model=on_accum_model, 112 on_accum_model=on_accum_model,
151 on_train=on_train, 113 on_train=on_train,
152 on_eval=on_eval, 114 on_eval=on_eval,
153 on_after_optimize=on_after_optimize,
154 on_log=on_log,
155 on_checkpoint=on_checkpoint, 115 on_checkpoint=on_checkpoint,
156 on_sample=on_sample, 116 on_sample=on_sample,
157 ) 117 )