From 1abbfd5215a99dba9d699e91baec00e6f02a0bd5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 17 Jan 2023 08:13:39 +0100
Subject: Update

---
 data/csv.py                     | 2 +-
 models/clip/embeddings.py       | 3 +++
 training/strategy/dreambooth.py | 8 ++++----
 training/strategy/ti.py         | 8 ++++----
 4 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 85b98f8..b4c81d7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -282,7 +282,7 @@ class VlpnDataModule():
         collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)

         if valid_set_size == 0:
-            data_train, data_val = items, items[:1]
+            data_train, data_val = items, items[:self.batch_size]
         else:
             data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9a23a2a..1cc59d9 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -100,6 +100,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeds

     def normalize(self, target: float = 0.4, lambda_: float = 1.0):
+        if lambda_ == 0:
+            return
+
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index bc26ee6..d813b49 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -88,7 +88,7 @@ def dreambooth_strategy_callbacks(
         ema_unet = None

     def ema_context():
-        if use_ema:
+        if ema_unet is not None:
             return ema_unet.apply_temporary(unet.parameters())
         else:
             return nullcontext()
@@ -102,7 +102,7 @@ def dreambooth_strategy_callbacks(
         text_encoder.text_model.embeddings.persist()
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)

-        if use_ema:
+        if ema_unet is not None:
             ema_unet.to(accelerator.device)

     @contextmanager
@@ -134,11 +134,11 @@ def dreambooth_strategy_callbacks(

     @torch.no_grad()
     def on_after_optimize(lr: float):
-        if use_ema:
+        if ema_unet is not None:
             ema_unet.step(unet.parameters())

     def on_log():
-        if use_ema:
+        if ema_unet is not None:
             return {"ema_decay": ema_unet.decay}
         return {}

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 597abd0..081180f 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -88,7 +88,7 @@ def textual_inversion_strategy_callbacks(
         ema_embeddings = None

     def ema_context():
-        if use_ema:
+        if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
                 text_encoder.text_model.embeddings.temp_token_embedding.parameters()
             )
@@ -101,7 +101,7 @@ def textual_inversion_strategy_callbacks(
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

-        if use_ema:
+        if ema_embeddings is not None:
             ema_embeddings.to(accelerator.device)

         if gradient_checkpointing:
@@ -120,7 +120,7 @@ def textual_inversion_strategy_callbacks(
         yield

     def on_after_optimize(lr: float):
-        if use_ema:
+        if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())

     @torch.no_grad()
@@ -132,7 +132,7 @@ def textual_inversion_strategy_callbacks(
             )

     def on_log():
-        if use_ema:
+        if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
         return {}

--
cgit v1.2.3-70-g09d2