diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/strategy/ti.py | 76 |
1 file changed, 18 insertions, 58 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 19b8d25..33f5fb9 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py | |||
@@ -1,6 +1,6 @@ | |||
1 | from typing import Optional | 1 | from typing import Optional |
2 | from functools import partial | 2 | from functools import partial |
3 | from contextlib import contextmanager, nullcontext | 3 | from contextlib import contextmanager |
4 | from pathlib import Path | 4 | from pathlib import Path |
5 | 5 | ||
6 | import torch | 6 | import torch |
@@ -13,7 +13,6 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch | |||
13 | from slugify import slugify | 13 | from slugify import slugify |
14 | 14 | ||
15 | from models.clip.tokenizer import MultiCLIPTokenizer | 15 | from models.clip.tokenizer import MultiCLIPTokenizer |
16 | from training.util import EMAModel | ||
17 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | 16 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
18 | 17 | ||
19 | 18 | ||
@@ -32,10 +31,6 @@ def textual_inversion_strategy_callbacks( | |||
32 | placeholder_tokens: list[str], | 31 | placeholder_tokens: list[str], |
33 | placeholder_token_ids: list[list[int]], | 32 | placeholder_token_ids: list[list[int]], |
34 | gradient_checkpointing: bool = False, | 33 | gradient_checkpointing: bool = False, |
35 | use_ema: bool = False, | ||
36 | ema_inv_gamma: float = 1.0, | ||
37 | ema_power: int = 1, | ||
38 | ema_max_decay: float = 0.9999, | ||
39 | sample_batch_size: int = 1, | 34 | sample_batch_size: int = 1, |
40 | sample_num_batches: int = 1, | 35 | sample_num_batches: int = 1, |
41 | sample_num_steps: int = 20, | 36 | sample_num_steps: int = 20, |
@@ -68,25 +63,6 @@ def textual_inversion_strategy_callbacks( | |||
68 | image_size=sample_image_size, | 63 | image_size=sample_image_size, |
69 | ) | 64 | ) |
70 | 65 | ||
71 | if use_ema: | ||
72 | ema_embeddings = EMAModel( | ||
73 | text_encoder.text_model.embeddings.overlay.parameters(), | ||
74 | inv_gamma=ema_inv_gamma, | ||
75 | power=ema_power, | ||
76 | max_value=ema_max_decay, | ||
77 | ) | ||
78 | ema_embeddings.to(accelerator.device) | ||
79 | else: | ||
80 | ema_embeddings = None | ||
81 | |||
82 | def ema_context(): | ||
83 | if ema_embeddings is not None: | ||
84 | return ema_embeddings.apply_temporary( | ||
85 | text_encoder.text_model.embeddings.overlay.parameters() | ||
86 | ) | ||
87 | else: | ||
88 | return nullcontext() | ||
89 | |||
90 | def on_accum_model(): | 66 | def on_accum_model(): |
91 | return text_encoder.text_model.embeddings.overlay | 67 | return text_encoder.text_model.embeddings.overlay |
92 | 68 | ||
@@ -98,50 +74,36 @@ def textual_inversion_strategy_callbacks( | |||
98 | @contextmanager | 74 | @contextmanager |
99 | def on_eval(): | 75 | def on_eval(): |
100 | tokenizer.eval() | 76 | tokenizer.eval() |
101 | 77 | yield | |
102 | with ema_context(): | ||
103 | yield | ||
104 | |||
105 | @torch.no_grad() | ||
106 | def on_after_optimize(zero_ids, lr: float): | ||
107 | if ema_embeddings is not None: | ||
108 | ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters()) | ||
109 | |||
110 | def on_log(): | ||
111 | if ema_embeddings is not None: | ||
112 | return {"ema_decay": ema_embeddings.decay} | ||
113 | return {} | ||
114 | 78 | ||
115 | @torch.no_grad() | 79 | @torch.no_grad() |
116 | def on_checkpoint(step, postfix): | 80 | def on_checkpoint(step, postfix): |
117 | print(f"Saving checkpoint for step {step}...") | 81 | print(f"Saving checkpoint for step {step}...") |
118 | 82 | ||
119 | with ema_context(): | 83 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): |
120 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 84 | text_encoder.text_model.embeddings.save_embed( |
121 | text_encoder.text_model.embeddings.save_embed( | 85 | ids, |
122 | ids, | 86 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" |
123 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 87 | ) |
124 | ) | ||
125 | 88 | ||
126 | @torch.no_grad() | 89 | @torch.no_grad() |
127 | def on_sample(step): | 90 | def on_sample(step): |
128 | with ema_context(): | 91 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
129 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 92 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
130 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | ||
131 | 93 | ||
132 | orig_unet_dtype = unet_.dtype | 94 | orig_unet_dtype = unet_.dtype |
133 | orig_text_encoder_dtype = text_encoder_.dtype | 95 | orig_text_encoder_dtype = text_encoder_.dtype |
134 | 96 | ||
135 | unet_.to(dtype=weight_dtype) | 97 | unet_.to(dtype=weight_dtype) |
136 | text_encoder_.to(dtype=weight_dtype) | 98 | text_encoder_.to(dtype=weight_dtype) |
137 | 99 | ||
138 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 100 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) |
139 | 101 | ||
140 | unet_.to(dtype=orig_unet_dtype) | 102 | unet_.to(dtype=orig_unet_dtype) |
141 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 103 | text_encoder_.to(dtype=orig_text_encoder_dtype) |
142 | 104 | ||
143 | del unet_ | 105 | del unet_ |
144 | del text_encoder_ | 106 | del text_encoder_ |
145 | 107 | ||
146 | if torch.cuda.is_available(): | 108 | if torch.cuda.is_available(): |
147 | torch.cuda.empty_cache() | 109 | torch.cuda.empty_cache() |
@@ -150,8 +112,6 @@ def textual_inversion_strategy_callbacks( | |||
150 | on_accum_model=on_accum_model, | 112 | on_accum_model=on_accum_model, |
151 | on_train=on_train, | 113 | on_train=on_train, |
152 | on_eval=on_eval, | 114 | on_eval=on_eval, |
153 | on_after_optimize=on_after_optimize, | ||
154 | on_log=on_log, | ||
155 | on_checkpoint=on_checkpoint, | 115 | on_checkpoint=on_checkpoint, |
156 | on_sample=on_sample, | 116 | on_sample=on_sample, |
157 | ) | 117 | ) |