diff options
Diffstat (limited to 'training/strategy')
| -rw-r--r-- | training/strategy/dreambooth.py | 10 | ||||
| -rw-r--r-- | training/strategy/ti.py | 19 |
2 files changed, 21 insertions, 8 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 93c81cb..bc26ee6 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch | |||
| 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
| 18 | from training.functional import TrainingCallbacks, save_samples | 18 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
| 19 | 19 | ||
| 20 | 20 | ||
| 21 | def dreambooth_strategy( | 21 | def dreambooth_strategy_callbacks( |
| 22 | accelerator: Accelerator, | 22 | accelerator: Accelerator, |
| 23 | unet: UNet2DConditionModel, | 23 | unet: UNet2DConditionModel, |
| 24 | text_encoder: CLIPTextModel, | 24 | text_encoder: CLIPTextModel, |
| @@ -185,3 +185,9 @@ def dreambooth_strategy( | |||
| 185 | on_checkpoint=on_checkpoint, | 185 | on_checkpoint=on_checkpoint, |
| 186 | on_sample=on_sample, | 186 | on_sample=on_sample, |
| 187 | ) | 187 | ) |
| 188 | |||
| 189 | |||
| 190 | dreambooth_strategy = TrainingStrategy( | ||
| 191 | callbacks=dreambooth_strategy_callbacks, | ||
| 192 | prepare_unet=True | ||
| 193 | ) | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 00f3529..597abd0 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -15,10 +15,10 @@ from slugify import slugify | |||
| 15 | 15 | ||
| 16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
| 18 | from training.functional import TrainingCallbacks, save_samples | 18 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
| 19 | 19 | ||
| 20 | 20 | ||
| 21 | def textual_inversion_strategy( | 21 | def textual_inversion_strategy_callbacks( |
| 22 | accelerator: Accelerator, | 22 | accelerator: Accelerator, |
| 23 | unet: UNet2DConditionModel, | 23 | unet: UNet2DConditionModel, |
| 24 | text_encoder: CLIPTextModel, | 24 | text_encoder: CLIPTextModel, |
| @@ -119,17 +119,18 @@ def textual_inversion_strategy( | |||
| 119 | with ema_context(): | 119 | with ema_context(): |
| 120 | yield | 120 | yield |
| 121 | 121 | ||
| 122 | @torch.no_grad() | ||
| 123 | def on_after_optimize(lr: float): | 122 | def on_after_optimize(lr: float): |
| 123 | if use_ema: | ||
| 124 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
| 125 | |||
| 126 | @torch.no_grad() | ||
| 127 | def on_after_epoch(lr: float): | ||
| 124 | if use_emb_decay: | 128 | if use_emb_decay: |
| 125 | text_encoder.text_model.embeddings.normalize( | 129 | text_encoder.text_model.embeddings.normalize( |
| 126 | emb_decay_target, | 130 | emb_decay_target, |
| 127 | min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))) | 131 | min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))) |
| 128 | ) | 132 | ) |
| 129 | 133 | ||
| 130 | if use_ema: | ||
| 131 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
| 132 | |||
| 133 | def on_log(): | 134 | def on_log(): |
| 134 | if use_ema: | 135 | if use_ema: |
| 135 | return {"ema_decay": ema_embeddings.decay} | 136 | return {"ema_decay": ema_embeddings.decay} |
| @@ -157,7 +158,13 @@ def textual_inversion_strategy( | |||
| 157 | on_train=on_train, | 158 | on_train=on_train, |
| 158 | on_eval=on_eval, | 159 | on_eval=on_eval, |
| 159 | on_after_optimize=on_after_optimize, | 160 | on_after_optimize=on_after_optimize, |
| 161 | on_after_epoch=on_after_epoch, | ||
| 160 | on_log=on_log, | 162 | on_log=on_log, |
| 161 | on_checkpoint=on_checkpoint, | 163 | on_checkpoint=on_checkpoint, |
| 162 | on_sample=on_sample, | 164 | on_sample=on_sample, |
| 163 | ) | 165 | ) |
| 166 | |||
| 167 | |||
| 168 | textual_inversion_strategy = TrainingStrategy( | ||
| 169 | callbacks=textual_inversion_strategy_callbacks, | ||
| 170 | ) | ||
