| author | Volpeon <git@volpeon.ink> | 2023-04-01 17:33:00 +0200 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-01 17:33:00 +0200 | 
| commit | 86e908656bcd7585ec45cd930176800f759f146a (patch) | |
| tree | 1169e9b1728e4c6fc8b70e46a37080ae0794ada8 /training | |
| parent | Experimental: TI via LoRA (diff) | |
| download | textual-inversion-diff-86e908656bcd7585ec45cd930176800f759f146a.tar.gz textual-inversion-diff-86e908656bcd7585ec45cd930176800f759f146a.tar.bz2 textual-inversion-diff-86e908656bcd7585ec45cd930176800f759f146a.zip | |
Combined TI with embedding and LoRA
Diffstat (limited to 'training')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | training/strategy/ti.py | 76 |

1 file changed, 18 insertions, 58 deletions
```diff
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 19b8d25..33f5fb9 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,6 +1,6 @@
 from typing import Optional
 from functools import partial
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from pathlib import Path
 
 import torch
@@ -13,7 +13,6 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -32,10 +31,6 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
-    use_ema: bool = False,
-    ema_inv_gamma: float = 1.0,
-    ema_power: int = 1,
-    ema_max_decay: float = 0.9999,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
     sample_num_steps: int = 20,
@@ -68,25 +63,6 @@ def textual_inversion_strategy_callbacks(
         image_size=sample_image_size,
     )
 
-    if use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.overlay.parameters(),
-            inv_gamma=ema_inv_gamma,
-            power=ema_power,
-            max_value=ema_max_decay,
-        )
-        ema_embeddings.to(accelerator.device)
-    else:
-        ema_embeddings = None
-
-    def ema_context():
-        if ema_embeddings is not None:
-            return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.overlay.parameters()
-            )
-        else:
-            return nullcontext()
-
     def on_accum_model():
         return text_encoder.text_model.embeddings.overlay
 
@@ -98,50 +74,36 @@ def textual_inversion_strategy_callbacks(
     @contextmanager
     def on_eval():
         tokenizer.eval()
-
-        with ema_context():
-            yield
-
-    @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
-
-    def on_log():
-        if ema_embeddings is not None:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
+        yield
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
-                )
+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
 
     @torch.no_grad()
     def on_sample(step):
-        with ema_context():
-            unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
-            orig_unet_dtype = unet_.dtype
-            orig_text_encoder_dtype = text_encoder_.dtype
+        orig_unet_dtype = unet_.dtype
+        orig_text_encoder_dtype = text_encoder_.dtype
 
-            unet_.to(dtype=weight_dtype)
-            text_encoder_.to(dtype=weight_dtype)
+        unet_.to(dtype=weight_dtype)
+        text_encoder_.to(dtype=weight_dtype)
 
-            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
-            unet_.to(dtype=orig_unet_dtype)
-            text_encoder_.to(dtype=orig_text_encoder_dtype)
+        unet_.to(dtype=orig_unet_dtype)
+        text_encoder_.to(dtype=orig_text_encoder_dtype)
 
-            del unet_
-            del text_encoder_
+        del unet_
+        del text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -150,8 +112,6 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
-        on_after_optimize=on_after_optimize,
-        on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
    )
```
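
A note on what the removed code did: the old strategy kept an exponential moving average (EMA) of the embeddings overlay, updated it after every optimizer step, and temporarily swapped the averaged weights in around evaluation, checkpointing, and sampling via `ema_context()`. The sketch below illustrates that temporary-swap pattern in isolation. `SimpleEMA` is a toy stand-in written for this note, not the `EMAModel` removed from `training.util`, and the fixed decay here replaces the real helper's inverse-gamma/power schedule.

```python
from contextlib import contextmanager

import torch


class SimpleEMA:
    """Toy EMA helper for illustration only; not the repository's EMAModel."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = [p.detach().clone() for p in params]

    @torch.no_grad()
    def step(self, params):
        # Exponential moving average update of the shadow copies.
        for s, p in zip(self.shadow, params):
            s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

    @contextmanager
    def apply_temporary(self, params):
        # Swap the EMA weights in, run the wrapped block (eval / checkpoint /
        # sample), then restore the live training weights afterwards.
        params = list(params)
        backup = [p.detach().clone() for p in params]
        with torch.no_grad():
            for p, s in zip(params, self.shadow):
                p.copy_(s)
        try:
            yield
        finally:
            with torch.no_grad():
                for p, b in zip(params, backup):
                    p.copy_(b)


# Usage mirroring the removed callbacks: step() after each optimizer update,
# apply_temporary() around anything that should see the averaged weights.
overlay = torch.nn.Linear(4, 4)
ema = SimpleEMA(overlay.parameters())
ema.step(overlay.parameters())
with ema.apply_temporary(overlay.parameters()):
    pass  # evaluate, save embeddings, or render samples here
```

With the EMA path gone, `on_eval`, `on_checkpoint`, and `on_sample` operate directly on the live embeddings, and the `on_after_optimize`/`on_log` callbacks had nothing left to do, which is why they are dropped from the returned `TrainingCallbacks`.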
