summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 22:42:44 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 22:42:44 +0100
commitf00877a13bce50b02cfc3790f2d18a325e9ff95b (patch)
treeebbda04024081e9c3c00400fae98124f3db2cc9c /training/util.py
parentUpdate (diff)
downloadtextual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.gz
textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.bz2
textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.zip
Update
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py112
1 files changed, 0 insertions, 112 deletions
diff --git a/training/util.py b/training/util.py
index a292edd..f46cc61 100644
--- a/training/util.py
+++ b/training/util.py
@@ -14,29 +14,6 @@ from models.clip.tokenizer import MultiCLIPTokenizer
14from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 14from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
15 15
16 16
class TrainingStrategy():
    """Collection of overridable hook points with do-nothing defaults.

    Every member defined here is a stub: the property and the plain hooks
    return ``None``, the context managers yield without doing any work, and
    ``on_log`` returns an empty dict.
    NOTE(review): the hook names suggest a training loop calls them around
    train/eval/optimizer steps -- confirm against the caller.
    """

    @property
    def main_model(self) -> torch.nn.Module:
        """Stub property; this base implementation returns ``None``."""
        ...

    @contextmanager
    def on_train(self, epoch: int):
        """Context manager that takes ``epoch`` and yields once with no setup or teardown."""
        yield

    @contextmanager
    def on_eval(self):
        """Context manager that yields once with no setup or teardown."""
        yield

    def on_before_optimize(self, epoch: int):
        """Stub hook taking ``epoch``; does nothing and returns ``None``."""
        ...

    def on_after_optimize(self, lr: float):
        """Stub hook taking ``lr``; does nothing and returns ``None``."""
        ...

    def on_log(self):
        """Return an empty dict.

        Takes ``self`` like every other hook so it can be called on an
        instance; without it ``strategy.on_log()`` raised ``TypeError``.
        """
        return {}
38
39
40def save_args(basepath: Path, args, extra={}): 17def save_args(basepath: Path, args, extra={}):
41 info = {"args": vars(args)} 18 info = {"args": vars(args)}
42 info["args"].update(extra) 19 info["args"].update(extra)
@@ -44,95 +21,6 @@ def save_args(basepath: Path, args, extra={}):
44 json.dump(info, f, indent=4) 21 json.dump(info, f, indent=4)
45 22
46 23
def generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    scheduler,
    data_train,
    sample_batch_size,
    sample_image_size,
    sample_steps
):
    """Generate and save the class image for every item whose file is missing.

    Items of ``data_train`` are read through their ``class_image_path``
    (checked with ``.exists()``), ``cprompt`` and ``nprompt`` attributes.
    Items lacking a file are grouped into batches of ``sample_batch_size``,
    run through a ``VlpnStableDiffusion`` pipeline placed on
    ``accelerator.device``, and each resulting image is saved to the
    matching ``class_image_path``. ``sample_image_size`` is used for both
    height and width; ``sample_steps`` is the number of inference steps.

    Returns ``None``. When no item is missing its image the function
    returns before building the pipeline.
    """
    pending = []
    for entry in data_train:
        if not entry.class_image_path.exists():
            pending.append(entry)

    # Nothing to generate: skip building the pipeline entirely.
    if not pending:
        return

    # Batches are computed up front, before the pipeline is created.
    chunks = []
    for start in range(0, len(pending), sample_batch_size):
        chunks.append(pending[start:start + sample_batch_size])

    sampler = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    )
    sampler = sampler.to(accelerator.device)
    sampler.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for chunk in chunks:
            targets = [entry.class_image_path for entry in chunk]
            positive = [entry.cprompt for entry in chunk]
            negative = [entry.nprompt for entry in chunk]

            output = sampler(
                prompt=positive,
                negative_prompt=negative,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            )

            # Images are matched to their target paths by position.
            for position, picture in enumerate(output.images):
                picture.save(targets[position])

    # Drop the local reference before asking CUDA to release cached memory.
    del sampler

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
99
100
def get_models(pretrained_model_name_or_path: str):
    """Load every model component from one pretrained checkpoint location.

    Each component is obtained with ``from_pretrained`` on its class, using
    ``pretrained_model_name_or_path`` and a component-specific ``subfolder``.
    Both schedulers are loaded from the same ``'scheduler'`` subfolder. The
    text encoder is then passed to ``patch_managed_embeddings`` and the
    result is returned as the last element.

    Returns:
        Tuple ``(tokenizer, text_encoder, vae, unet, noise_scheduler,
        sample_scheduler, embeddings)``.
    """
    def load(component_cls, subfolder):
        # Single place for the shared from_pretrained call shape.
        return component_cls.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder)

    tokenizer = load(MultiCLIPTokenizer, 'tokenizer')
    text_encoder = load(CLIPTextModel, 'text_encoder')
    vae = load(AutoencoderKL, 'vae')
    unet = load(UNet2DConditionModel, 'unet')
    noise_scheduler = load(DDPMScheduler, 'scheduler')
    sample_scheduler = load(DPMSolverMultistepScheduler, 'scheduler')

    embeddings = patch_managed_embeddings(text_encoder)

    return (
        tokenizer,
        text_encoder,
        vae,
        unet,
        noise_scheduler,
        sample_scheduler,
        embeddings,
    )
113
114
def add_placeholder_tokens(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    placeholder_tokens: list[str],
    initializer_tokens: list[str],
    num_vectors: Union[list[int], int]
):
    """Register placeholder tokens and add an embedding entry for each.

    Each initializer token is encoded with ``add_special_tokens=False``.
    The placeholders are then added through ``tokenizer.add_multi_tokens``
    with ``num_vectors``, and ``embeddings`` is resized to
    ``len(tokenizer)``. Finally ``embeddings.add_embed`` is called once per
    (placeholder ids, initializer ids) pair. Pairing uses ``zip``, so it
    stops at the shorter of the two lists.

    Returns:
        Tuple ``(placeholder_token_ids, initializer_token_ids)``.
    """
    # Encode initializers first, before the tokenizer is modified.
    init_ids = []
    for init_token in initializer_tokens:
        init_ids.append(tokenizer.encode(init_token, add_special_tokens=False))

    new_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)

    # Resize after the tokenizer has grown, using its new length.
    embeddings.resize(len(tokenizer))

    for new_id, init_id in zip(new_ids, init_ids):
        embeddings.add_embed(new_id, init_id)

    return new_ids, init_ids
134
135
136class AverageMeter: 24class AverageMeter:
137 def __init__(self, name=None): 25 def __init__(self, name=None):
138 self.name = name 26 self.name = name