diff options
-rw-r--r-- | data/csv.py | 2 | ||||
-rw-r--r-- | models/clip/embeddings.py | 3 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 8 | ||||
-rw-r--r-- | training/strategy/ti.py | 8 |
4 files changed, 12 insertions, 9 deletions
diff --git a/data/csv.py b/data/csv.py index 85b98f8..b4c81d7 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -282,7 +282,7 @@ class VlpnDataModule(): | |||
282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
283 | 283 | ||
284 | if valid_set_size == 0: | 284 | if valid_set_size == 0: |
285 | data_train, data_val = items, items[:1] | 285 | data_train, data_val = items, items[:self.batch_size] |
286 | else: | 286 | else: |
287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) |
288 | 288 | ||
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9a23a2a..1cc59d9 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -100,6 +100,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
100 | return embeds | 100 | return embeds |
101 | 101 | ||
102 | def normalize(self, target: float = 0.4, lambda_: float = 1.0): | 102 | def normalize(self, target: float = 0.4, lambda_: float = 1.0): |
103 | if lambda_ == 0: | ||
104 | return | ||
105 | |||
103 | w = self.temp_token_embedding.weight | 106 | w = self.temp_token_embedding.weight |
104 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | 107 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) |
105 | w[self.temp_token_ids] = F.normalize( | 108 | w[self.temp_token_ids] = F.normalize( |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index bc26ee6..d813b49 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -88,7 +88,7 @@ def dreambooth_strategy_callbacks( | |||
88 | ema_unet = None | 88 | ema_unet = None |
89 | 89 | ||
90 | def ema_context(): | 90 | def ema_context(): |
91 | if use_ema: | 91 | if ema_unet is not None: |
92 | return ema_unet.apply_temporary(unet.parameters()) | 92 | return ema_unet.apply_temporary(unet.parameters()) |
93 | else: | 93 | else: |
94 | return nullcontext() | 94 | return nullcontext() |
@@ -102,7 +102,7 @@ def dreambooth_strategy_callbacks( | |||
102 | text_encoder.text_model.embeddings.persist() | 102 | text_encoder.text_model.embeddings.persist() |
103 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) | 103 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) |
104 | 104 | ||
105 | if use_ema: | 105 | if ema_unet is not None: |
106 | ema_unet.to(accelerator.device) | 106 | ema_unet.to(accelerator.device) |
107 | 107 | ||
108 | @contextmanager | 108 | @contextmanager |
@@ -134,11 +134,11 @@ def dreambooth_strategy_callbacks( | |||
134 | 134 | ||
135 | @torch.no_grad() | 135 | @torch.no_grad() |
136 | def on_after_optimize(lr: float): | 136 | def on_after_optimize(lr: float): |
137 | if use_ema: | 137 | if ema_unet is not None: |
138 | ema_unet.step(unet.parameters()) | 138 | ema_unet.step(unet.parameters()) |
139 | 139 | ||
140 | def on_log(): | 140 | def on_log(): |
141 | if use_ema: | 141 | if ema_unet is not None: |
142 | return {"ema_decay": ema_unet.decay} | 142 | return {"ema_decay": ema_unet.decay} |
143 | return {} | 143 | return {} |
144 | 144 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 597abd0..081180f 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -88,7 +88,7 @@ def textual_inversion_strategy_callbacks( | |||
88 | ema_embeddings = None | 88 | ema_embeddings = None |
89 | 89 | ||
90 | def ema_context(): | 90 | def ema_context(): |
91 | if use_ema: | 91 | if ema_embeddings is not None: |
92 | return ema_embeddings.apply_temporary( | 92 | return ema_embeddings.apply_temporary( |
93 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 93 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() |
94 | ) | 94 | ) |
@@ -101,7 +101,7 @@ def textual_inversion_strategy_callbacks( | |||
101 | def on_prepare(): | 101 | def on_prepare(): |
102 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) | 102 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) |
103 | 103 | ||
104 | if use_ema: | 104 | if ema_embeddings is not None: |
105 | ema_embeddings.to(accelerator.device) | 105 | ema_embeddings.to(accelerator.device) |
106 | 106 | ||
107 | if gradient_checkpointing: | 107 | if gradient_checkpointing: |
@@ -120,7 +120,7 @@ def textual_inversion_strategy_callbacks( | |||
120 | yield | 120 | yield |
121 | 121 | ||
122 | def on_after_optimize(lr: float): | 122 | def on_after_optimize(lr: float): |
123 | if use_ema: | 123 | if ema_embeddings is not None: |
124 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 124 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |
125 | 125 | ||
126 | @torch.no_grad() | 126 | @torch.no_grad() |
@@ -132,7 +132,7 @@ def textual_inversion_strategy_callbacks( | |||
132 | ) | 132 | ) |
133 | 133 | ||
134 | def on_log(): | 134 | def on_log(): |
135 | if use_ema: | 135 | if ema_embeddings is not None: |
136 | return {"ema_decay": ema_embeddings.decay} | 136 | return {"ema_decay": ema_embeddings.decay} |
137 | return {} | 137 | return {} |
138 | 138 | ||