-rw-r--r--   train_ti.py                                                       10
-rw-r--r--   trainer_old/base.py (renamed from trainer/base.py)                 0
-rw-r--r--   trainer_old/dreambooth.py (renamed from trainer/dreambooth.py)     0
-rw-r--r--   trainer_old/ti.py (renamed from trainer/ti.py)                     4
-rw-r--r--   training/functional.py                                            34
-rw-r--r--   training/util.py                                                 112
6 files changed, 33 insertions, 127 deletions
diff --git a/train_ti.py b/train_ti.py
index a4e2dde..78c1b5c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -11,20 +11,16 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, UNet2DConditionModel
 import matplotlib.pyplot as plt
-from transformers import CLIPTextModel
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from trainer.base import Checkpointer
+from trainer_old.base import Checkpointer
 from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
-from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -485,12 +481,16 @@ class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
         ema_embeddings: EMAModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
         self.ema_embeddings = ema_embeddings
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
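Note: the checkpointer in train_ti.py now receives the placeholder tokens and their ids explicitly at construction time. A minimal sketch of what the call site might look like after this change; the trailing keyword arguments forwarded to the Checkpointer base class are assumptions not shown in this diff:

    checkpointer = TextualInversionCheckpointer(
        ema_embeddings=ema_embeddings,
        placeholder_tokens=placeholder_tokens,        # e.g. ["<concept>"]
        placeholder_token_ids=placeholder_token_ids,  # one list of ids per placeholder token
        accelerator=accelerator,                      # assumed base-class argument
        output_dir=output_dir,                        # assumed base-class argument
    )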
diff --git a/trainer/base.py b/trainer_old/base.py
index 1f85e71..1f85e71 100644
--- a/trainer/base.py
+++ b/trainer_old/base.py
diff --git a/trainer/dreambooth.py b/trainer_old/dreambooth.py
index e69de29..e69de29 100644
--- a/trainer/dreambooth.py
+++ b/trainer_old/dreambooth.py
diff --git a/trainer/ti.py b/trainer_old/ti.py
index 388acd3..66393af 100644
--- a/trainer/ti.py
+++ b/trainer_old/ti.py
@@ -15,12 +15,16 @@ class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
         ema_embeddings: EMAModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
         self.ema_embeddings = ema_embeddings
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
diff --git a/training/functional.py b/training/functional.py
index c100ea2..c5b514a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,17 +25,31 @@ def const(result=None):
     return fn
 
 
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
 def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    unet: UNet2DConditionModel,
+    tokenizer: MultiCLIPTokenizer,
+    sample_scheduler: DPMSolverMultistepScheduler,
     data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
+    sample_batch_size: int,
+    sample_image_size: int,
+    sample_steps: int
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
@@ -52,7 +66,7 @@ def generate_class_images(
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=scheduler,
+        scheduler=sample_scheduler,
     ).to(accelerator.device)
     pipeline.set_progress_bar_config(dynamic_ncols=True)
 
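Note: get_models() bundles loading of the tokenizer, text encoder, VAE, UNet and both schedulers, and patches the text encoder with managed embeddings. A minimal usage sketch, imported as train_ti.py does; the model path is only a placeholder:

    from training.functional import get_models

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "path/to/pretrained-stable-diffusion"  # placeholder model name or local path
    )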
diff --git a/training/util.py b/training/util.py
index a292edd..f46cc61 100644
--- a/training/util.py
+++ b/training/util.py
@@ -14,29 +14,6 @@ from models.clip.tokenizer import MultiCLIPTokenizer
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 
 
-class TrainingStrategy():
-    @property
-    def main_model(self) -> torch.nn.Module:
-        ...
-
-    @contextmanager
-    def on_train(self, epoch: int):
-        yield
-
-    @contextmanager
-    def on_eval(self):
-        yield
-
-    def on_before_optimize(self, epoch: int):
-        ...
-
-    def on_after_optimize(self, lr: float):
-        ...
-
-    def on_log():
-        return {}
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
@@ -44,95 +21,6 @@ def save_args(basepath: Path, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
-    data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
-):
-    missing_data = [item for item in data_train if not item.class_image_path.exists()]
-
-    if len(missing_data) == 0:
-        return
-
-    batched_data = [
-        missing_data[i:i+sample_batch_size]
-        for i in range(0, len(missing_data), sample_batch_size)
-    ]
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    ).to(accelerator.device)
-    pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-    with torch.inference_mode():
-        for batch in batched_data:
-            image_name = [item.class_image_path for item in batch]
-            prompt = [item.cprompt for item in batch]
-            nprompt = [item.nprompt for item in batch]
-
-            images = pipeline(
-                prompt=prompt,
-                negative_prompt=nprompt,
-                height=sample_image_size,
-                width=sample_image_size,
-                num_inference_steps=sample_steps
-            ).images
-
-            for i, image in enumerate(images):
-                image.save(image_name[i])
-
-    del pipeline
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
-def add_placeholder_tokens(
-    tokenizer: MultiCLIPTokenizer,
-    embeddings: ManagedCLIPTextEmbeddings,
-    placeholder_tokens: list[str],
-    initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
-):
-    initializer_token_ids = [
-        tokenizer.encode(token, add_special_tokens=False)
-        for token in initializer_tokens
-    ]
-    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
-
-    embeddings.resize(len(tokenizer))
-
-    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id)
-
-    return placeholder_token_ids, initializer_token_ids
-
-
 class AverageMeter:
     def __init__(self, name=None):
         self.name = name
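Note: generate_class_images(), get_models() and add_placeholder_tokens() are removed from training/util.py; train_ti.py imports them from training.functional instead. A sketch of calling add_placeholder_tokens() with the signature removed above, using the tokenizer and embeddings returned by get_models(); the token strings are made-up examples:

    from training.functional import add_placeholder_tokens

    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=["<concept>"],  # new tokens to learn
        initializer_tokens=["person"],     # existing tokens that seed the new embeddings
        num_vectors=1,                     # or a per-token list of vector counts
    )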
