summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py12
1 files changed, 3 insertions, 9 deletions
diff --git a/training/functional.py b/training/functional.py
index 49c21c7..56c2995 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
20 20
21from data.csv import VlpnDataset 21from data.csv import VlpnDataset
22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
23from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 23from models.clip.embeddings import ManagedCLIPTextEmbeddings
24from models.clip.util import get_extended_embeddings 24from models.clip.util import get_extended_embeddings
25from models.clip.tokenizer import MultiCLIPTokenizer 25from models.clip.tokenizer import MultiCLIPTokenizer
26from models.convnext.discriminator import ConvNeXtDiscriminator 26from models.convnext.discriminator import ConvNeXtDiscriminator
@@ -68,11 +68,7 @@ class TrainingStrategy():
68 prepare: TrainingStrategyPrepareCallable 68 prepare: TrainingStrategyPrepareCallable
69 69
70 70
71def get_models( 71def get_models(pretrained_model_name_or_path: str):
72 pretrained_model_name_or_path: str,
73 emb_alpha: int = 8,
74 emb_dropout: float = 0.0
75):
76 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 72 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
77 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 73 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
78 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 74 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,9 +77,7 @@ def get_models(
81 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 77 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
82 pretrained_model_name_or_path, subfolder='scheduler') 78 pretrained_model_name_or_path, subfolder='scheduler')
83 79
84 embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout) 80 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
85
86 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
87 81
88 82
89def save_samples( 83def save_samples(