diff options
author | Volpeon <git@volpeon.ink> | 2023-04-03 18:52:30 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-03 18:52:30 +0200 |
commit | e68cb3542e08c9f22ce8a94fd88bebe0c121ca17 (patch) | |
tree | 87fbb9d92233aa1bb7342e31aec64d6d375f41e1 /training | |
parent | TI: No tag dropout by default (diff) | |
download | textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.tar.gz textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.tar.bz2 textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.zip |
TI: Delta learning
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 4 | ||||
-rw-r--r-- | training/strategy/ti.py | 23 |
2 files changed, 2 insertions, 25 deletions
diff --git a/training/functional.py b/training/functional.py index 96ecbc1..1d8e2ee 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols): | |||
73 | return grid | 73 | return grid |
74 | 74 | ||
75 | 75 | ||
76 | def get_models(pretrained_model_name_or_path: str): | 76 | def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): |
77 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 77 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
78 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 78 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
79 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 79 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str): | |||
82 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 82 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
83 | pretrained_model_name_or_path, subfolder='scheduler') | 83 | pretrained_model_name_or_path, subfolder='scheduler') |
84 | 84 | ||
85 | embeddings = patch_managed_embeddings(text_encoder) | 85 | embeddings = patch_managed_embeddings(text_encoder, emb_alpha) |
86 | 86 | ||
87 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 87 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
88 | 88 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index c7520ed..16baa34 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -31,10 +31,6 @@ def textual_inversion_strategy_callbacks( | |||
31 | seed: int, | 31 | seed: int, |
32 | placeholder_tokens: list[str], | 32 | placeholder_tokens: list[str], |
33 | placeholder_token_ids: list[list[int]], | 33 | placeholder_token_ids: list[list[int]], |
34 | gradient_checkpointing: bool = False, | ||
35 | use_emb_decay: bool = False, | ||
36 | emb_decay_target: float = 0.4, | ||
37 | emb_decay: float = 1e-2, | ||
38 | use_ema: bool = False, | 34 | use_ema: bool = False, |
39 | ema_inv_gamma: float = 1.0, | 35 | ema_inv_gamma: float = 1.0, |
40 | ema_power: int = 1, | 36 | ema_power: int = 1, |
@@ -106,28 +102,10 @@ def textual_inversion_strategy_callbacks( | |||
106 | yield | 102 | yield |
107 | 103 | ||
108 | @torch.no_grad() | 104 | @torch.no_grad() |
109 | def on_before_optimize(lr: float, epoch: int): | ||
110 | if use_emb_decay: | ||
111 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | ||
112 | return torch.all(w.grad == 0, dim=1) | ||
113 | |||
114 | @torch.no_grad() | ||
115 | def on_after_optimize(zero_ids, lr: float): | 105 | def on_after_optimize(zero_ids, lr: float): |
116 | if ema_embeddings is not None: | 106 | if ema_embeddings is not None: |
117 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 107 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |
118 | 108 | ||
119 | if use_emb_decay: | ||
120 | lambda_ = emb_decay * lr | ||
121 | |||
122 | if lambda_ != 0: | ||
123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | ||
124 | |||
125 | mask = torch.ones(w.shape[0], dtype=torch.bool) | ||
126 | mask[zero_ids] = False | ||
127 | |||
128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | ||
129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
130 | |||
131 | def on_log(): | 109 | def on_log(): |
132 | if ema_embeddings is not None: | 110 | if ema_embeddings is not None: |
133 | return {"ema_decay": ema_embeddings.decay} | 111 | return {"ema_decay": ema_embeddings.decay} |
@@ -171,7 +149,6 @@ def textual_inversion_strategy_callbacks( | |||
171 | on_accum_model=on_accum_model, | 149 | on_accum_model=on_accum_model, |
172 | on_train=on_train, | 150 | on_train=on_train, |
173 | on_eval=on_eval, | 151 | on_eval=on_eval, |
174 | on_before_optimize=on_before_optimize, | ||
175 | on_after_optimize=on_after_optimize, | 152 | on_after_optimize=on_after_optimize, |
176 | on_log=on_log, | 153 | on_log=on_log, |
177 | on_checkpoint=on_checkpoint, | 154 | on_checkpoint=on_checkpoint, |