-rw-r--r-- | train_ti.py                                                    |  10
-rw-r--r-- | trainer_old/base.py (renamed from trainer/base.py)             |   0
-rw-r--r-- | trainer_old/dreambooth.py (renamed from trainer/dreambooth.py) |   0
-rw-r--r-- | trainer_old/ti.py (renamed from trainer/ti.py)                 |   4
-rw-r--r-- | training/functional.py                                         |  34
-rw-r--r-- | training/util.py                                               | 112
6 files changed, 33 insertions, 127 deletions
diff --git a/train_ti.py b/train_ti.py
index a4e2dde..78c1b5c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -11,20 +11,16 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, UNet2DConditionModel
 import matplotlib.pyplot as plt
-from transformers import CLIPTextModel
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from trainer.base import Checkpointer
+from trainer_old.base import Checkpointer
 from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
-from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -485,12 +481,16 @@ class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
         ema_embeddings: EMAModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
         self.ema_embeddings = ema_embeddings
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
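
Callers in train_ti.py now have to pass the placeholder token data through to the checkpointer. A minimal construction sketch, assuming the base Checkpointer takes the remaining objects via **kwargs (every keyword other than the three visible in this diff is a hypothetical name, not taken from the source):

# Sketch only: keyword names below the first three are assumptions about
# the base Checkpointer's signature, not shown in this diff.
checkpointer = TextualInversionCheckpointer(
    ema_embeddings=ema_embeddings,
    placeholder_tokens=args.placeholder_tokens,    # e.g. ["<token-a>", "<token-b>"]
    placeholder_token_ids=placeholder_token_ids,   # one list of ids per placeholder token
    accelerator=accelerator,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
)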
diff --git a/trainer/base.py b/trainer_old/base.py
index 1f85e71..1f85e71 100644
--- a/trainer/base.py
+++ b/trainer_old/base.py
diff --git a/trainer/dreambooth.py b/trainer_old/dreambooth.py
index e69de29..e69de29 100644
--- a/trainer/dreambooth.py
+++ b/trainer_old/dreambooth.py
diff --git a/trainer/ti.py b/trainer_old/ti.py
index 388acd3..66393af 100644
--- a/trainer/ti.py
+++ b/trainer_old/ti.py
@@ -15,12 +15,16 @@ class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
         ema_embeddings: EMAModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
         self.ema_embeddings = ema_embeddings
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
diff --git a/training/functional.py b/training/functional.py
index c100ea2..c5b514a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,17 +25,31 @@ def const(result=None):
     return fn
 
 
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
 def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    unet: UNet2DConditionModel,
+    tokenizer: MultiCLIPTokenizer,
+    sample_scheduler: DPMSolverMultistepScheduler,
     data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
+    sample_batch_size: int,
+    sample_image_size: int,
+    sample_steps: int
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
@@ -52,7 +66,7 @@ def generate_class_images(
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=scheduler,
+        scheduler=sample_scheduler,
     ).to(accelerator.device)
     pipeline.set_progress_bar_config(dynamic_ncols=True)
 
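
With get_models moved here and generate_class_images fully typed, training/functional.py becomes the single place that loads the models. A rough usage sketch; the datamodule and args attributes are illustrative, not taken from this diff:

# Sketch: wiring the two helpers together in a training script.
tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    args.pretrained_model_name_or_path  # any diffusers-style checkpoint directory or hub id
)

generate_class_images(
    accelerator=accelerator,
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
    sample_scheduler=sample_scheduler,   # parameter renamed from `scheduler`
    data_train=datamodule.data_train,    # assumed attribute of the VlpnDataModule
    sample_batch_size=args.sample_batch_size,
    sample_image_size=args.sample_image_size,
    sample_steps=args.sample_steps,
)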
diff --git a/training/util.py b/training/util.py
index a292edd..f46cc61 100644
--- a/training/util.py
+++ b/training/util.py
@@ -14,29 +14,6 @@ from models.clip.tokenizer import MultiCLIPTokenizer
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 
 
-class TrainingStrategy():
-    @property
-    def main_model(self) -> torch.nn.Module:
-        ...
-
-    @contextmanager
-    def on_train(self, epoch: int):
-        yield
-
-    @contextmanager
-    def on_eval(self):
-        yield
-
-    def on_before_optimize(self, epoch: int):
-        ...
-
-    def on_after_optimize(self, lr: float):
-        ...
-
-    def on_log():
-        return {}
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
@@ -44,95 +21,6 @@ def save_args(basepath: Path, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
-    data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
-):
-    missing_data = [item for item in data_train if not item.class_image_path.exists()]
-
-    if len(missing_data) == 0:
-        return
-
-    batched_data = [
-        missing_data[i:i+sample_batch_size]
-        for i in range(0, len(missing_data), sample_batch_size)
-    ]
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    ).to(accelerator.device)
-    pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-    with torch.inference_mode():
-        for batch in batched_data:
-            image_name = [item.class_image_path for item in batch]
-            prompt = [item.cprompt for item in batch]
-            nprompt = [item.nprompt for item in batch]
-
-            images = pipeline(
-                prompt=prompt,
-                negative_prompt=nprompt,
-                height=sample_image_size,
-                width=sample_image_size,
-                num_inference_steps=sample_steps
-            ).images
-
-            for i, image in enumerate(images):
-                image.save(image_name[i])
-
-    del pipeline
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
-def add_placeholder_tokens(
-    tokenizer: MultiCLIPTokenizer,
-    embeddings: ManagedCLIPTextEmbeddings,
-    placeholder_tokens: list[str],
-    initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
-):
-    initializer_token_ids = [
-        tokenizer.encode(token, add_special_tokens=False)
-        for token in initializer_tokens
-    ]
-    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
-
-    embeddings.resize(len(tokenizer))
-
-    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id)
-
-    return placeholder_token_ids, initializer_token_ids
-
-
 class AverageMeter:
     def __init__(self, name=None):
         self.name = name
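
The helpers deleted here (generate_class_images, get_models, add_placeholder_tokens) now live in training/functional.py, matching the updated import in train_ti.py. A sketch of calling add_placeholder_tokens from its new module; the token strings and vector count are illustrative values, not from this diff:

# Sketch: adding textual-inversion placeholder tokens via training.functional.
from training.functional import add_placeholder_tokens, get_models

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    args.pretrained_model_name_or_path
)

placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
    tokenizer=tokenizer,
    embeddings=embeddings,
    placeholder_tokens=["<concept>"],  # new token(s) to learn
    initializer_tokens=["person"],     # existing token(s) whose embedding seeds each new one
    num_vectors=1,                     # vectors per placeholder token (int or per-token list)
)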