summary | refs | log | tree | commit | diff | stats
path: root/training
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink>, 2023-01-14 22:42:44 +0100
committer: Volpeon <git@volpeon.ink>, 2023-01-14 22:42:44 +0100
commit: f00877a13bce50b02cfc3790f2d18a325e9ff95b (patch)
tree: ebbda04024081e9c3c00400fae98124f3db2cc9c /training
parent: Update (diff)
download: textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.gz
textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.bz2
textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.zip
Update
Diffstat (limited to 'training')
-rw-r--r-- training/functional.py 34
-rw-r--r-- training/util.py 112
2 files changed, 24 insertions, 122 deletions
diff --git a/training/functional.py b/training/functional.py
index c100ea2..c5b514a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,17 +25,31 @@ def const(result=None):
25 return fn 25 return fn
26 26
27 27
def get_models(pretrained_model_name_or_path: str):
    """Load the pretrained components used for training and sampling.

    Every component is read with ``from_pretrained`` from its own subfolder
    of ``pretrained_model_name_or_path``. The loaded text encoder is then
    passed to ``patch_managed_embeddings``; whatever that call returns is
    handed back as ``embeddings``.

    Args:
        pretrained_model_name_or_path: Forwarded unchanged to every
            ``from_pretrained`` call.

    Returns:
        The tuple ``(tokenizer, text_encoder, vae, unet, noise_scheduler,
        sample_scheduler, embeddings)``.
    """
    source = pretrained_model_name_or_path

    tokenizer = MultiCLIPTokenizer.from_pretrained(source, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(source, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(source, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(source, subfolder='unet')

    # Both schedulers are created from the same 'scheduler' subfolder.
    noise_scheduler = DDPMScheduler.from_pretrained(source, subfolder='scheduler')
    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(source, subfolder='scheduler')

    embeddings = patch_managed_embeddings(text_encoder)

    return (
        tokenizer,
        text_encoder,
        vae,
        unet,
        noise_scheduler,
        sample_scheduler,
        embeddings,
    )
40
41
28def generate_class_images( 42def generate_class_images(
29 accelerator, 43 accelerator: Accelerator,
30 text_encoder, 44 text_encoder: CLIPTextModel,
31 vae, 45 vae: AutoencoderKL,
32 unet, 46 unet: UNet2DConditionModel,
33 tokenizer, 47 tokenizer: MultiCLIPTokenizer,
34 scheduler, 48 sample_scheduler: DPMSolverMultistepScheduler,
35 data_train, 49 data_train,
36 sample_batch_size, 50 sample_batch_size: int,
37 sample_image_size, 51 sample_image_size: int,
38 sample_steps 52 sample_steps: int
39): 53):
40 missing_data = [item for item in data_train if not item.class_image_path.exists()] 54 missing_data = [item for item in data_train if not item.class_image_path.exists()]
41 55
@@ -52,7 +66,7 @@ def generate_class_images(
52 vae=vae, 66 vae=vae,
53 unet=unet, 67 unet=unet,
54 tokenizer=tokenizer, 68 tokenizer=tokenizer,
55 scheduler=scheduler, 69 scheduler=sample_scheduler,
56 ).to(accelerator.device) 70 ).to(accelerator.device)
57 pipeline.set_progress_bar_config(dynamic_ncols=True) 71 pipeline.set_progress_bar_config(dynamic_ncols=True)
58 72
diff --git a/training/util.py b/training/util.py
index a292edd..f46cc61 100644
--- a/training/util.py
+++ b/training/util.py
@@ -14,29 +14,6 @@ from models.clip.tokenizer import MultiCLIPTokenizer
14from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 14from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
15 15
16 16
class TrainingStrategy():
    """Base class whose hooks all default to doing nothing.

    Subclasses override the hooks they need. When and how often each hook
    is invoked is decided by the caller, which is not part of this class.
    """

    @property
    def main_model(self) -> torch.nn.Module:
        """Model associated with this strategy.

        The base implementation has an empty body and therefore yields
        ``None``; subclasses are expected to return a real module.
        """
        ...

    @contextmanager
    def on_train(self, epoch: int):
        """Context manager hook taking the epoch number; no-op by default."""
        yield

    @contextmanager
    def on_eval(self):
        """Context manager hook for evaluation; no-op by default."""
        yield

    def on_before_optimize(self, epoch: int):
        """Hook taking the epoch number; no-op by default."""
        ...

    def on_after_optimize(self, lr: float):
        """Hook taking a learning rate; no-op by default."""
        ...

    def on_log(self):
        """Return extra values to log; the default is an empty dict.

        Takes ``self`` like every other hook so it can be called on an
        instance (without it, ``strategy.on_log()`` raises ``TypeError``).
        """
        return {}
38
39
40def save_args(basepath: Path, args, extra={}): 17def save_args(basepath: Path, args, extra={}):
41 info = {"args": vars(args)} 18 info = {"args": vars(args)}
42 info["args"].update(extra) 19 info["args"].update(extra)
@@ -44,95 +21,6 @@ def save_args(basepath: Path, args, extra={}):
44 json.dump(info, f, indent=4) 21 json.dump(info, f, indent=4)
45 22
46 23
def generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    scheduler,
    data_train,
    sample_batch_size,
    sample_image_size,
    sample_steps
):
    """Generate and save the class images that do not exist on disk yet.

    Items of ``data_train`` whose ``class_image_path`` already exists are
    skipped. The remaining items are processed in batches of
    ``sample_batch_size``: each batch's ``cprompt`` / ``nprompt`` values are
    passed to a ``VlpnStableDiffusion`` pipeline and every resulting image
    is saved to the matching item's ``class_image_path``.

    Args:
        accelerator: Object whose ``device`` the pipeline is moved to.
        text_encoder, vae, unet, tokenizer, scheduler: Components handed to
            ``VlpnStableDiffusion``.
        data_train: Iterable of items exposing ``class_image_path``,
            ``cprompt`` and ``nprompt``.
        sample_batch_size: Number of items per pipeline call.
        sample_image_size: Used for both ``height`` and ``width``.
        sample_steps: Passed as ``num_inference_steps``.

    Returns:
        None. Returns immediately when no class image is missing.
    """
    missing_data = [item for item in data_train if not item.class_image_path.exists()]

    if not missing_data:
        return

    batched_data = [
        missing_data[start:start + sample_batch_size]
        for start in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    try:
        with torch.inference_mode():
            for batch in batched_data:
                images = pipeline(
                    prompt=[item.cprompt for item in batch],
                    negative_prompt=[item.nprompt for item in batch],
                    height=sample_image_size,
                    width=sample_image_size,
                    num_inference_steps=sample_steps
                ).images

                # Pair each image with the item it was generated for.
                for item, image in zip(batch, images):
                    image.save(item.class_image_path)
    finally:
        # Drop the pipeline and empty the CUDA cache even when sampling
        # fails, so an error here does not skip the cleanup.
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
99
100
def get_models(pretrained_model_name_or_path: str):
    """Build the tokenizer, models, schedulers and managed embeddings.

    All components come from ``from_pretrained`` calls against
    ``pretrained_model_name_or_path``, each with its own ``subfolder``.
    ``patch_managed_embeddings`` is then called with the text encoder and
    its return value is included in the result.

    Returns:
        ``(tokenizer, text_encoder, vae, unet, noise_scheduler,
        sample_scheduler, embeddings)``
    """
    def _load(component_cls, subfolder):
        # Single place that forwards the shared model location.
        return component_cls.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder)

    tokenizer = _load(MultiCLIPTokenizer, 'tokenizer')
    text_encoder = _load(CLIPTextModel, 'text_encoder')
    vae = _load(AutoencoderKL, 'vae')
    unet = _load(UNet2DConditionModel, 'unet')
    # The training and sampling schedulers share the 'scheduler' subfolder.
    noise_scheduler = _load(DDPMScheduler, 'scheduler')
    sample_scheduler = _load(DPMSolverMultistepScheduler, 'scheduler')

    embeddings = patch_managed_embeddings(text_encoder)

    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
113
114
def add_placeholder_tokens(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    placeholder_tokens: list[str],
    initializer_tokens: list[str],
    num_vectors: Union[list[int], int]
):
    """Register placeholder tokens and add an embedding for each of them.

    The initializer tokens are encoded without special tokens, the
    placeholder tokens are added through ``tokenizer.add_multi_tokens``,
    and ``embeddings`` is resized to ``len(tokenizer)``. Each placeholder
    id is then passed to ``embeddings.add_embed`` together with the
    initializer ids at the same position.

    Args:
        tokenizer: Tokenizer that encodes the initializers and receives the
            new placeholder tokens.
        embeddings: Embeddings object that is resized and extended.
        placeholder_tokens: Tokens to add.
        initializer_tokens: Tokens whose encoded ids accompany each
            placeholder, matched by position.
        num_vectors: Forwarded to ``tokenizer.add_multi_tokens``.

    Returns:
        ``(placeholder_token_ids, initializer_token_ids)``.
    """
    encode = tokenizer.encode
    initializer_token_ids = [
        encode(initializer, add_special_tokens=False)
        for initializer in initializer_tokens
    ]

    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)

    embeddings.resize(len(tokenizer))

    # zip stops at the shorter sequence, so unmatched entries are skipped
    # silently when the two lists differ in length.
    id_pairs = zip(placeholder_token_ids, initializer_token_ids)
    for new_token_id, source_token_id in id_pairs:
        embeddings.add_embed(new_token_id, source_token_id)

    return placeholder_token_ids, initializer_token_ids
134
135
136class AverageMeter: 24class AverageMeter:
137 def __init__(self, name=None): 25 def __init__(self, name=None):
138 self.name = name 26 self.name = name