Diffstat (limited to 'training')
-rw-r--r--   training/functional.py    | 11
-rw-r--r--   training/strategy/lora.py |  4
-rw-r--r--   training/strategy/ti.py   |  9
3 files changed, 15 insertions, 9 deletions
diff --git a/training/functional.py b/training/functional.py
index 54bbe78..1fdfdc8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -66,7 +66,12 @@ class TrainingStrategy():
     prepare: TrainingStrategyPrepareCallable
 
 
-def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
+def get_models(
+    pretrained_model_name_or_path: str,
+    emb_r: int = 8,
+    emb_lora_alpha: int = 8,
+    emb_lora_dropout: float = 0.0
+):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_dropout)
+    embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
@@ -653,6 +658,8 @@ def train_loop(
         on_checkpoint(global_step, "end")
         raise KeyboardInterrupt
 
+    return avg_loss, avg_acc, avg_loss_val, avg_acc_val
+
 
 def train(
     accelerator: Accelerator,
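For orientation, a minimal usage sketch of the two changes above. Only get_models, train_loop, the new emb_r/emb_lora_alpha/emb_lora_dropout keywords and the new return tuple come from this commit; the model id and the surrounding glue are assumptions.

# Hypothetical usage sketch; not part of this commit.
from training.functional import get_models, train_loop

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    "runwayml/stable-diffusion-v1-5",  # example model id, not taken from this repo
    emb_r=8,               # LoRA rank for the managed embeddings
    emb_lora_alpha=8,      # LoRA scaling factor
    emb_lora_dropout=0.0,  # dropout on the embedding LoRA adapter
)

# train_loop now hands back its final metrics in addition to its side effects:
# avg_loss, avg_acc, avg_loss_val, avg_acc_val = train_loop(...)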
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1517ee8..48236fb 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -93,7 +93,7 @@ def lora_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
+                for p in text_encoder.text_model.embeddings.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
@@ -180,7 +180,7 @@ def lora_prepare(
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
-    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
 
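The first hunk keeps the existing pattern for the embedding-decay callback, collect only parameters that actually received a gradient and stack them, but now walks the plain embeddings module instead of the removed token_override_embedding. A self-contained sketch of that pattern, with a toy nn.Embedding standing in for the text encoder embeddings:

# Sketch of the grad-filtering pattern used in lora_strategy_callbacks; the toy model is made up.
import torch
import torch.nn as nn

embeddings = nn.Embedding(10, 4)  # stand-in for text_encoder.text_model.embeddings
loss = embeddings(torch.tensor([1, 3])).sum()
loss.backward()

# Same filter as in the diff: keep only parameters that received a gradient this step.
params = [p for p in embeddings.parameters() if p.grad is not None]
w = torch.stack(params) if len(params) != 0 else None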
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index ca7cc3d..49236c6 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+            text_encoder.text_model.embeddings.token_embedding.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -84,7 +84,7 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.token_override_embedding.parameters()
+                text_encoder.text_model.embeddings.token_embedding.parameters()
             )
         else:
             return nullcontext()
@@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
+                for p in text_encoder.text_model.embeddings.token_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
@@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters())
+            ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -203,7 +203,6 @@ def textual_inversion_prepare(
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
 
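With token_override_embedding gone, the EMA and decay callbacks operate on the regular token_embedding, and textual_inversion_prepare no longer freezes it, presumably so that patch_managed_embeddings can attach the new LoRA parameters (emb_r, emb_lora_alpha, emb_lora_dropout) to it. A standalone sketch of the resulting freezing pattern, using a plain CLIPTextModel from transformers; the model id is an example and nothing here uses this repo's own helpers:

# Sketch only: which parts of the text encoder stay trainable after this commit.
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

# Still frozen in textual_inversion_prepare ...
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# ... but the token embedding is no longer frozen here; the EMA/decay callbacks train it directly.

trainable = [n for n, p in text_encoder.named_parameters() if p.requires_grad]
print(trainable)  # only the token_embedding weight remains trainable in this standalone sketch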