summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 22:42:44 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 22:42:44 +0100
commitf00877a13bce50b02cfc3790f2d18a325e9ff95b (patch)
treeebbda04024081e9c3c00400fae98124f3db2cc9c /training/functional.py
parentUpdate (diff)
downloadtextual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.gz
textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.bz2
textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.zip
Update
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py34
1 files changed, 24 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py
index c100ea2..c5b514a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,17 +25,31 @@ def const(result=None):
25 return fn 25 return fn
26 26
27 27
28def get_models(pretrained_model_name_or_path: str):
29 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
30 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
31 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
32 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
33 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
34 sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
35 pretrained_model_name_or_path, subfolder='scheduler')
36
37 embeddings = patch_managed_embeddings(text_encoder)
38
39 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
40
41
28def generate_class_images( 42def generate_class_images(
29 accelerator, 43 accelerator: Accelerator,
30 text_encoder, 44 text_encoder: CLIPTextModel,
31 vae, 45 vae: AutoencoderKL,
32 unet, 46 unet: UNet2DConditionModel,
33 tokenizer, 47 tokenizer: MultiCLIPTokenizer,
34 scheduler, 48 sample_scheduler: DPMSolverMultistepScheduler,
35 data_train, 49 data_train,
36 sample_batch_size, 50 sample_batch_size: int,
37 sample_image_size, 51 sample_image_size: int,
38 sample_steps 52 sample_steps: int
39): 53):
40 missing_data = [item for item in data_train if not item.class_image_path.exists()] 54 missing_data = [item for item in data_train if not item.class_image_path.exists()]
41 55
@@ -52,7 +66,7 @@ def generate_class_images(
52 vae=vae, 66 vae=vae,
53 unet=unet, 67 unet=unet,
54 tokenizer=tokenizer, 68 tokenizer=tokenizer,
55 scheduler=scheduler, 69 scheduler=sample_scheduler,
56 ).to(accelerator.device) 70 ).to(accelerator.device)
57 pipeline.set_progress_bar_config(dynamic_ncols=True) 71 pipeline.set_progress_bar_config(dynamic_ncols=True)
58 72