author     Volpeon <git@volpeon.ink>  2023-01-17 08:13:39 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-17 08:13:39 +0100
commit     1abbfd5215a99dba9d699e91baec00e6f02a0bd5
tree       670e846b3c08bd8957955ea56d3a4c4b58a8ad6f
parent     Update
Update
 data/csv.py                     | 2 +-
 models/clip/embeddings.py       | 3 +++
 training/strategy/dreambooth.py | 8 ++++----
 training/strategy/ti.py         | 8 ++++----
 4 files changed, 12 insertions(+), 9 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 85b98f8..b4c81d7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -282,7 +282,7 @@ class VlpnDataModule():
         collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
 
         if valid_set_size == 0:
-            data_train, data_val = items, items[:1]
+            data_train, data_val = items, items[:self.batch_size]
         else:
             data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
 
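A minimal, self-contained sketch of what the data/csv.py change does, assuming `items` is the list of dataset entries and `batch_size` stands in for `self.batch_size` (only the names from the hunk above are taken from the repository):

items = list(range(10))
batch_size = 4
valid_set_size = 0

if valid_set_size == 0:
    # Before this commit the fallback validation set was items[:1], i.e. a
    # single sample; reusing one full batch of training items means the
    # validation loader always yields at least one complete batch.
    data_train, data_val = items, items[:batch_size]

print(len(data_train), len(data_val))  # 10 4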
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9a23a2a..1cc59d9 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -100,6 +100,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeds
 
     def normalize(self, target: float = 0.4, lambda_: float = 1.0):
+        if lambda_ == 0:
+            return
+
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
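The new guard skips the renormalization entirely when `lambda_` is 0, since a zero blend weight would leave every embedding row unchanged anyway. A minimal sketch of the pattern, assuming a plain `nn.Embedding` in place of `temp_token_embedding` and a list of ids in place of `temp_token_ids` (not the repository's class):

import torch
import torch.nn.functional as F

def normalize(embedding: torch.nn.Embedding, token_ids: list,
              target: float = 0.4, lambda_: float = 1.0):
    if lambda_ == 0:
        return  # blending by 0 is a no-op, so skip the work entirely

    with torch.no_grad():
        w = embedding.weight
        # Current L2 norm of each managed row, then move it toward `target`
        # by a factor of lambda_ while keeping each row's direction.
        pre_norm = w[token_ids, :].norm(dim=-1, keepdim=True)
        w[token_ids] = F.normalize(w[token_ids, :], dim=-1) * (
            pre_norm + lambda_ * (target - pre_norm)
        )

emb = torch.nn.Embedding(10, 8)
normalize(emb, [3, 5])                  # lambda_=1.0 snaps the norms to 0.4
print(emb.weight[[3, 5]].norm(dim=-1))  # both rows now close to 0.4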
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index bc26ee6..d813b49 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -88,7 +88,7 @@ def dreambooth_strategy_callbacks(
     ema_unet = None
 
     def ema_context():
-        if use_ema:
+        if ema_unet is not None:
             return ema_unet.apply_temporary(unet.parameters())
         else:
             return nullcontext()
@@ -102,7 +102,7 @@ def dreambooth_strategy_callbacks(
         text_encoder.text_model.embeddings.persist()
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
 
-        if use_ema:
+        if ema_unet is not None:
             ema_unet.to(accelerator.device)
 
     @contextmanager
@@ -134,11 +134,11 @@ def dreambooth_strategy_callbacks(
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        if use_ema:
+        if ema_unet is not None:
             ema_unet.step(unet.parameters())
 
     def on_log():
-        if use_ema:
+        if ema_unet is not None:
             return {"ema_decay": ema_unet.decay}
         return {}
 
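The dreambooth callbacks now test the EMA object itself instead of the `use_ema` flag, so no callback can dereference `None` if the flag and the object ever fall out of sync. A short, self-contained sketch of the guard; the stub class and names are illustrative only, not the repository's real EMA class:

from contextlib import nullcontext

class StubEMA:
    decay = 0.999
    def apply_temporary(self, params):
        # the real EMA model temporarily swaps the averaged weights in here
        return nullcontext()
    def step(self, params):
        pass

ema_unet = None          # EMA disabled; set to StubEMA() when enabled

def ema_context():
    # Guarding on the object rather than a separate boolean keeps the
    # callbacks safe even if the flag says EMA is on but no model exists.
    if ema_unet is not None:
        return ema_unet.apply_temporary([])
    return nullcontext()

def on_log():
    if ema_unet is not None:
        return {"ema_decay": ema_unet.decay}
    return {}

with ema_context():
    print(on_log())      # {} when disabled, {'ema_decay': 0.999} when enabled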
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 597abd0..081180f 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -88,7 +88,7 @@ def textual_inversion_strategy_callbacks(
     ema_embeddings = None
 
     def ema_context():
-        if use_ema:
+        if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
                 text_encoder.text_model.embeddings.temp_token_embedding.parameters()
             )
@@ -101,7 +101,7 @@ def textual_inversion_strategy_callbacks(
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
 
-        if use_ema:
+        if ema_embeddings is not None:
             ema_embeddings.to(accelerator.device)
 
         if gradient_checkpointing:
@@ -120,7 +120,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     def on_after_optimize(lr: float):
-        if use_ema:
+        if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
     @torch.no_grad()
@@ -132,7 +132,7 @@ def textual_inversion_strategy_callbacks(
         )
 
     def on_log():
-        if use_ema:
+        if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
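The training/strategy/ti.py hunks apply the same guard to the embedding EMA. Sketched briefly below as a hypothetical no-op run with EMA disabled (names mirror the hunks above, the empty parameter list is a placeholder):

ema_embeddings = None    # EMA of temp_token_embedding is turned off

def on_after_optimize(lr: float):
    if ema_embeddings is not None:
        ema_embeddings.step([])   # not reached when EMA is disabled

on_after_optimize(1e-4)           # silently no-ops instead of touching None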