summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py12
-rw-r--r--training/strategy/lora.py4
2 files changed, 5 insertions, 11 deletions
diff --git a/training/functional.py b/training/functional.py
index 49c21c7..56c2995 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
20 20
21from data.csv import VlpnDataset 21from data.csv import VlpnDataset
22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
23from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 23from models.clip.embeddings import ManagedCLIPTextEmbeddings
24from models.clip.util import get_extended_embeddings 24from models.clip.util import get_extended_embeddings
25from models.clip.tokenizer import MultiCLIPTokenizer 25from models.clip.tokenizer import MultiCLIPTokenizer
26from models.convnext.discriminator import ConvNeXtDiscriminator 26from models.convnext.discriminator import ConvNeXtDiscriminator
@@ -68,11 +68,7 @@ class TrainingStrategy():
68 prepare: TrainingStrategyPrepareCallable 68 prepare: TrainingStrategyPrepareCallable
69 69
70 70
71def get_models( 71def get_models(pretrained_model_name_or_path: str):
72 pretrained_model_name_or_path: str,
73 emb_alpha: int = 8,
74 emb_dropout: float = 0.0
75):
76 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 72 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
77 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 73 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
78 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 74 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,9 +77,7 @@ def get_models(
81 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 77 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
82 pretrained_model_name_or_path, subfolder='scheduler') 78 pretrained_model_name_or_path, subfolder='scheduler')
83 79
84 embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout) 80 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
85
86 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
87 81
88 82
89def save_samples( 83def save_samples(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
92 max_grad_norm 92 max_grad_norm
93 ) 93 )
94 94
95 if use_emb_decay: 95 if len(placeholder_tokens) != 0 and use_emb_decay:
96 params = [ 96 params = [
97 p 97 p
98 for p in text_encoder.text_model.embeddings.parameters() 98 for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(
102 102
103 @torch.no_grad() 103 @torch.no_grad()
104 def on_after_optimize(w, lrs: dict[str, float]): 104 def on_after_optimize(w, lrs: dict[str, float]):
105 if use_emb_decay and w is not None and "emb" in lrs: 105 if w is not None and "emb" in lrs:
106 lr = lrs["emb"] 106 lr = lrs["emb"]
107 lambda_ = emb_decay * lr 107 lambda_ = emb_decay * lr
108 108