author     Volpeon <git@volpeon.ink>  2023-01-14 22:42:44 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-14 22:42:44 +0100
commit     f00877a13bce50b02cfc3790f2d18a325e9ff95b (patch)
tree       ebbda04024081e9c3c00400fae98124f3db2cc9c /training
parent     Update (diff)
Update
Diffstat (limited to 'training')
-rw-r--r--  training/functional.py |  34
-rw-r--r--  training/util.py       | 112
2 files changed, 24 insertions, 122 deletions
diff --git a/training/functional.py b/training/functional.py
index c100ea2..c5b514a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,17 +25,31 @@ def const(result=None):
     return fn
 
 
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
 def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    unet: UNet2DConditionModel,
+    tokenizer: MultiCLIPTokenizer,
+    sample_scheduler: DPMSolverMultistepScheduler,
     data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
+    sample_batch_size: int,
+    sample_image_size: int,
+    sample_steps: int
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
@@ -52,7 +66,7 @@ def generate_class_images(
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=scheduler,
+        scheduler=sample_scheduler,
     ).to(accelerator.device)
     pipeline.set_progress_bar_config(dynamic_ncols=True)
 
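As a quick orientation for the code added above, here is a minimal usage sketch of get_models() and generate_class_images() as they now live in training/functional.py. It is not part of the commit: the checkpoint path, the sampling parameters, and the ClassImageItem stand-in for the dataset items are placeholders assumed purely for illustration.

from dataclasses import dataclass
from pathlib import Path

from accelerate import Accelerator

from training.functional import get_models, generate_class_images


# Hypothetical stand-in for the real dataset items; generate_class_images()
# only reads class_image_path, cprompt and nprompt from each item.
@dataclass
class ClassImageItem:
    class_image_path: Path
    cprompt: str
    nprompt: str


# Placeholder checkpoint path; any diffusers-style checkpoint with the expected
# subfolders (tokenizer, text_encoder, vae, unet, scheduler) would do.
tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    "path/to/pretrained-model"
)

accelerator = Accelerator()

data_train = [
    ClassImageItem(Path("class_images/dog_0.png"), "a photo of a dog", "lowres, blurry"),
]

# Renders any missing class images; the pipeline is moved to the accelerator
# device inside the function itself.
generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    sample_scheduler,
    data_train,
    sample_batch_size=1,   # placeholder values, not taken from this commit
    sample_image_size=512,
    sample_steps=20,
)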
diff --git a/training/util.py b/training/util.py
index a292edd..f46cc61 100644
--- a/training/util.py
+++ b/training/util.py
@@ -14,29 +14,6 @@ from models.clip.tokenizer import MultiCLIPTokenizer
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 
 
-class TrainingStrategy():
-    @property
-    def main_model(self) -> torch.nn.Module:
-        ...
-
-    @contextmanager
-    def on_train(self, epoch: int):
-        yield
-
-    @contextmanager
-    def on_eval(self):
-        yield
-
-    def on_before_optimize(self, epoch: int):
-        ...
-
-    def on_after_optimize(self, lr: float):
-        ...
-
-    def on_log():
-        return {}
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
@@ -44,95 +21,6 @@ def save_args(basepath: Path, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
-    data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
-):
-    missing_data = [item for item in data_train if not item.class_image_path.exists()]
-
-    if len(missing_data) == 0:
-        return
-
-    batched_data = [
-        missing_data[i:i+sample_batch_size]
-        for i in range(0, len(missing_data), sample_batch_size)
-    ]
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    ).to(accelerator.device)
-    pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-    with torch.inference_mode():
-        for batch in batched_data:
-            image_name = [item.class_image_path for item in batch]
-            prompt = [item.cprompt for item in batch]
-            nprompt = [item.nprompt for item in batch]
-
-            images = pipeline(
-                prompt=prompt,
-                negative_prompt=nprompt,
-                height=sample_image_size,
-                width=sample_image_size,
-                num_inference_steps=sample_steps
-            ).images
-
-            for i, image in enumerate(images):
-                image.save(image_name[i])
-
-    del pipeline
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
-def add_placeholder_tokens(
-    tokenizer: MultiCLIPTokenizer,
-    embeddings: ManagedCLIPTextEmbeddings,
-    placeholder_tokens: list[str],
-    initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
-):
-    initializer_token_ids = [
-        tokenizer.encode(token, add_special_tokens=False)
-        for token in initializer_tokens
-    ]
-    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
-
-    embeddings.resize(len(tokenizer))
-
-    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id)
-
-    return placeholder_token_ids, initializer_token_ids
-
-
 class AverageMeter:
     def __init__(self, name=None):
         self.name = name
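A note on call sites (an inference, since the corresponding updates are not shown in this commit): code that previously imported these helpers from training.util would now import the relocated ones from training.functional; where add_placeholder_tokens and the TrainingStrategy scaffold end up is not visible in the hunks above.

# Import-path change implied by this commit; the actual call-site updates are
# not part of the shown diff, so treat this as an assumption.
# Before: from training.util import get_models, generate_class_images
from training.functional import get_models, generate_class_images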