-rw-r--r--  train_ti.py                                                     |  10
-rw-r--r--  trainer_old/base.py (renamed from trainer/base.py)              |   0
-rw-r--r--  trainer_old/dreambooth.py (renamed from trainer/dreambooth.py)  |   0
-rw-r--r--  trainer_old/ti.py (renamed from trainer/ti.py)                  |   4
-rw-r--r--  training/functional.py                                          |  34
-rw-r--r--  training/util.py                                                | 112
6 files changed, 33 insertions(+), 127 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index a4e2dde..78c1b5c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -11,20 +11,16 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, UNet2DConditionModel
 import matplotlib.pyplot as plt
-from transformers import CLIPTextModel
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from trainer.base import Checkpointer
+from trainer_old.base import Checkpointer
 from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
-from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -485,12 +481,16 @@ class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
         ema_embeddings: EMAModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
         self.ema_embeddings = ema_embeddings
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
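
The hunk above threads the placeholder tokens and their per-token id lists into TextualInversionCheckpointer explicitly instead of letting it reach into shared state. A minimal construction sketch, assuming ema_embeddings, placeholder_tokens, and placeholder_token_ids are already in scope in main() and that the base Checkpointer accepts its remaining arguments via **kwargs (the extra keyword names below are illustrative, not from this diff):

    checkpointer = TextualInversionCheckpointer(
        ema_embeddings=ema_embeddings,
        placeholder_tokens=placeholder_tokens,        # e.g. ["<token>"]
        placeholder_token_ids=placeholder_token_ids,  # one id list per token, e.g. [[49408, 49409]]
        accelerator=accelerator,                      # illustrative kwargs forwarded to Checkpointer
        output_dir=output_dir,
    )
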
diff --git a/trainer/base.py b/trainer_old/base.py
index 1f85e71..1f85e71 100644
--- a/trainer/base.py
+++ b/trainer_old/base.py
diff --git a/trainer/dreambooth.py b/trainer_old/dreambooth.py
index e69de29..e69de29 100644
--- a/trainer/dreambooth.py
+++ b/trainer_old/dreambooth.py
diff --git a/trainer/ti.py b/trainer_old/ti.py
index 388acd3..66393af 100644
--- a/trainer/ti.py
+++ b/trainer_old/ti.py
@@ -15,12 +15,16 @@ class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
         ema_embeddings: EMAModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
         self.ema_embeddings = ema_embeddings
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
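
This is the same change as in train_ti.py above. Storing placeholder_tokens alongside placeholder_token_ids lets checkpoint() write one embedding file per token. A sketch of that loop, assuming the managed embeddings expose a save_embed(ids, path) helper and a checkpoints output directory (both are assumptions, not shown in this diff):

    @torch.no_grad()
    def checkpoint(self, step, postfix):
        checkpoints_path = self.output_dir / "checkpoints"  # assumed layout
        checkpoints_path.mkdir(parents=True, exist_ok=True)
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
        for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
            # assumed helper on the patched managed embeddings
            text_encoder.text_model.embeddings.save_embed(
                ids, checkpoints_path / f"{slugify(token)}_{step}_{postfix}.bin")
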
diff --git a/training/functional.py b/training/functional.py
index c100ea2..c5b514a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,17 +25,31 @@ def const(result=None):
     return fn
 
 
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
 def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    unet: UNet2DConditionModel,
+    tokenizer: MultiCLIPTokenizer,
+    sample_scheduler: DPMSolverMultistepScheduler,
     data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
+    sample_batch_size: int,
+    sample_image_size: int,
+    sample_steps: int
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
@@ -52,7 +66,7 @@ def generate_class_images(
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=scheduler,
+        scheduler=sample_scheduler,
     ).to(accelerator.device)
     pipeline.set_progress_bar_config(dynamic_ncols=True)
 
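
get_models now lives next to the training loop in training/functional.py, and generate_class_images takes the DPM-Solver sampling scheduler under the clearer name sample_scheduler. A usage sketch (the model path and sample settings are illustrative, not from the diff):

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "runwayml/stable-diffusion-v1-5")  # illustrative model path

    generate_class_images(
        accelerator=accelerator,           # an existing accelerate.Accelerator
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        sample_scheduler=sample_scheduler,
        data_train=datamodule.data_train,  # items with class_image_path/cprompt/nprompt
        sample_batch_size=4,
        sample_image_size=768,
        sample_steps=20,
    )
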
diff --git a/training/util.py b/training/util.py
index a292edd..f46cc61 100644
--- a/training/util.py
+++ b/training/util.py
@@ -14,29 +14,6 @@ from models.clip.tokenizer import MultiCLIPTokenizer
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 
 
-class TrainingStrategy():
-    @property
-    def main_model(self) -> torch.nn.Module:
-        ...
-
-    @contextmanager
-    def on_train(self, epoch: int):
-        yield
-
-    @contextmanager
-    def on_eval(self):
-        yield
-
-    def on_before_optimize(self, epoch: int):
-        ...
-
-    def on_after_optimize(self, lr: float):
-        ...
-
-    def on_log():
-        return {}
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
@@ -44,95 +21,6 @@ def save_args(basepath: Path, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
-    data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
-):
-    missing_data = [item for item in data_train if not item.class_image_path.exists()]
-
-    if len(missing_data) == 0:
-        return
-
-    batched_data = [
-        missing_data[i:i+sample_batch_size]
-        for i in range(0, len(missing_data), sample_batch_size)
-    ]
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    ).to(accelerator.device)
-    pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-    with torch.inference_mode():
-        for batch in batched_data:
-            image_name = [item.class_image_path for item in batch]
-            prompt = [item.cprompt for item in batch]
-            nprompt = [item.nprompt for item in batch]
-
-            images = pipeline(
-                prompt=prompt,
-                negative_prompt=nprompt,
-                height=sample_image_size,
-                width=sample_image_size,
-                num_inference_steps=sample_steps
-            ).images
-
-            for i, image in enumerate(images):
-                image.save(image_name[i])
-
-    del pipeline
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
-def add_placeholder_tokens(
-    tokenizer: MultiCLIPTokenizer,
-    embeddings: ManagedCLIPTextEmbeddings,
-    placeholder_tokens: list[str],
-    initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
-):
-    initializer_token_ids = [
-        tokenizer.encode(token, add_special_tokens=False)
-        for token in initializer_tokens
-    ]
-    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
-
-    embeddings.resize(len(tokenizer))
-
-    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id)
-
-    return placeholder_token_ids, initializer_token_ids
-
-
 class AverageMeter:
     def __init__(self, name=None):
         self.name = name
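
The helpers deleted here are the duplicated copies: generate_class_images and get_models now live only in training/functional.py, and train_ti.py already imports add_placeholder_tokens from training.functional as well. For reference, add_placeholder_tokens resizes the managed embeddings and seeds each new multi-vector token from an existing word. A usage sketch with illustrative token names:

    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,           # ManagedCLIPTextEmbeddings from get_models()
        placeholder_tokens=["<skull>"],  # hypothetical new token
        initializer_tokens=["skull"],    # existing word whose embedding seeds it
        num_vectors=4,                   # each placeholder expands to 4 embedding vectors
    )
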