 data/csv.py                                          |  36
 models/clip/embeddings.py                            |   6
 models/clip/prompt.py                                |  38
 models/clip/util.py                                  |  34
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  20
 train_dreambooth.py                                  |   7
 train_ti.py                                          | 268
 training/common.py                                   | 205
 training/util.py                                     |  13
 9 files changed, 334 insertions(+), 293 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index f5fc8e6..a3fef30 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -9,9 +9,10 @@ from PIL import Image
9 9
10from torch.utils.data import IterableDataset, DataLoader, random_split 10from torch.utils.data import IterableDataset, DataLoader, random_split
11from torchvision import transforms 11from torchvision import transforms
12from transformers import CLIPTokenizer
12 13
13from data.keywords import prompt_to_keywords, keywords_to_prompt 14from data.keywords import prompt_to_keywords, keywords_to_prompt
14from models.clip.prompt import PromptProcessor 15from models.clip.util import unify_input_ids
15 16
16 17
17image_cache: dict[str, Image.Image] = {} 18image_cache: dict[str, Image.Image] = {}
@@ -102,7 +103,7 @@ def generate_buckets(
102def collate_fn( 103def collate_fn(
103 num_class_images: int, 104 num_class_images: int,
104 weight_dtype: torch.dtype, 105 weight_dtype: torch.dtype,
105 prompt_processor: PromptProcessor, 106 tokenizer: CLIPTokenizer,
106 examples 107 examples
107): 108):
108 prompt_ids = [example["prompt_ids"] for example in examples] 109 prompt_ids = [example["prompt_ids"] for example in examples]
@@ -119,9 +120,9 @@ def collate_fn(
119 pixel_values = torch.stack(pixel_values) 120 pixel_values = torch.stack(pixel_values)
120 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) 121 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
121 122
122 prompts = prompt_processor.unify_input_ids(prompt_ids) 123 prompts = unify_input_ids(tokenizer, prompt_ids)
123 nprompts = prompt_processor.unify_input_ids(nprompt_ids) 124 nprompts = unify_input_ids(tokenizer, nprompt_ids)
124 inputs = prompt_processor.unify_input_ids(input_ids) 125 inputs = unify_input_ids(tokenizer, input_ids)
125 126
126 batch = { 127 batch = {
127 "prompt_ids": prompts.input_ids, 128 "prompt_ids": prompts.input_ids,
@@ -148,7 +149,7 @@ class VlpnDataModule():
148 self, 149 self,
149 batch_size: int, 150 batch_size: int,
150 data_file: str, 151 data_file: str,
151 prompt_processor: PromptProcessor, 152 tokenizer: CLIPTokenizer,
152 class_subdir: str = "cls", 153 class_subdir: str = "cls",
153 num_class_images: int = 1, 154 num_class_images: int = 1,
154 size: int = 768, 155 size: int = 768,
@@ -179,7 +180,7 @@ class VlpnDataModule():
179 self.class_root.mkdir(parents=True, exist_ok=True) 180 self.class_root.mkdir(parents=True, exist_ok=True)
180 self.num_class_images = num_class_images 181 self.num_class_images = num_class_images
181 182
182 self.prompt_processor = prompt_processor 183 self.tokenizer = tokenizer
183 self.size = size 184 self.size = size
184 self.num_buckets = num_buckets 185 self.num_buckets = num_buckets
185 self.bucket_step_size = bucket_step_size 186 self.bucket_step_size = bucket_step_size
@@ -272,7 +273,7 @@ class VlpnDataModule():
272 self.data_val = self.pad_items(data_val) 273 self.data_val = self.pad_items(data_val)
273 274
274 train_dataset = VlpnDataset( 275 train_dataset = VlpnDataset(
275 self.data_train, self.prompt_processor, 276 self.data_train, self.tokenizer,
276 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, 277 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
277 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 278 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
278 batch_size=self.batch_size, generator=generator, 279 batch_size=self.batch_size, generator=generator,
@@ -281,7 +282,7 @@ class VlpnDataModule():
281 ) 282 )
282 283
283 val_dataset = VlpnDataset( 284 val_dataset = VlpnDataset(
284 self.data_val, self.prompt_processor, 285 self.data_val, self.tokenizer,
285 num_buckets=self.num_buckets, progressive_buckets=True, 286 num_buckets=self.num_buckets, progressive_buckets=True,
286 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 287 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
287 repeat=self.valid_set_repeat, 288 repeat=self.valid_set_repeat,
@@ -289,7 +290,7 @@ class VlpnDataModule():
289 size=self.size, interpolation=self.interpolation, 290 size=self.size, interpolation=self.interpolation,
290 ) 291 )
291 292
292 collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor) 293 collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer)
293 294
294 self.train_dataloader = DataLoader( 295 self.train_dataloader = DataLoader(
295 train_dataset, 296 train_dataset,
@@ -306,7 +307,7 @@ class VlpnDataset(IterableDataset):
306 def __init__( 307 def __init__(
307 self, 308 self,
308 items: list[VlpnDataItem], 309 items: list[VlpnDataItem],
309 prompt_processor: PromptProcessor, 310 tokenizer: CLIPTokenizer,
310 num_buckets: int = 1, 311 num_buckets: int = 1,
311 bucket_step_size: int = 64, 312 bucket_step_size: int = 64,
312 bucket_max_pixels: Optional[int] = None, 313 bucket_max_pixels: Optional[int] = None,
@@ -323,7 +324,7 @@ class VlpnDataset(IterableDataset):
323 self.items = items * repeat 324 self.items = items * repeat
324 self.batch_size = batch_size 325 self.batch_size = batch_size
325 326
326 self.prompt_processor = prompt_processor 327 self.tokenizer = tokenizer
327 self.num_class_images = num_class_images 328 self.num_class_images = num_class_images
328 self.size = size 329 self.size = size
329 self.dropout = dropout 330 self.dropout = dropout
@@ -344,6 +345,9 @@ class VlpnDataset(IterableDataset):
344 345
345 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() 346 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
346 347
348 def get_input_ids(self, text: str):
349 return self.tokenizer(text, padding="do_not_pad").input_ids
350
347 def __len__(self): 351 def __len__(self):
348 return self.length_ 352 return self.length_
349 353
@@ -404,16 +408,16 @@ class VlpnDataset(IterableDataset):
404 408
405 example = {} 409 example = {}
406 410
407 example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) 411 example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
408 example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) 412 example["nprompt_ids"] = self.get_input_ids(item.nprompt)
409 413
410 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( 414 example["instance_prompt_ids"] = self.get_input_ids(
411 keywords_to_prompt(item.prompt, self.dropout, True) 415 keywords_to_prompt(item.prompt, self.dropout, True)
412 ) 416 )
413 example["instance_images"] = image_transforms(get_image(item.instance_image_path)) 417 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
414 418
415 if self.num_class_images != 0: 419 if self.num_class_images != 0:
416 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) 420 example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
417 example["class_images"] = image_transforms(get_image(item.class_image_path)) 421 example["class_images"] = image_transforms(get_image(item.class_image_path))
418 422
419 batch.append(example) 423 batch.append(example)
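The data layer drops its PromptProcessor dependency: VlpnDataModule, VlpnDataset and collate_fn now take the CLIPTokenizer directly, the dataset gains its own get_input_ids(), and batch padding goes through the new unify_input_ids() helper. A minimal wiring sketch under those assumptions; the model id, CSV path and batch settings below are placeholders, and the remaining constructor arguments keep their defaults:

    from functools import partial

    import torch
    from transformers import CLIPTokenizer

    from data.csv import VlpnDataModule, collate_fn

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    datamodule = VlpnDataModule(
        batch_size=4,
        data_file="data/train.csv",   # placeholder path
        tokenizer=tokenizer,          # was: prompt_processor=PromptProcessor(tokenizer, text_encoder)
        num_class_images=1,
        size=768,
    )

    # Inside the data module the collate function is bound the same way:
    collate = partial(collate_fn, 1, torch.float16, tokenizer)
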
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9a23a2a..761efbc 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -40,6 +40,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
40 self.position_embedding = embeddings.position_embedding 40 self.position_embedding = embeddings.position_embedding
41 self.initializer_factor = config.initializer_factor 41 self.initializer_factor = config.initializer_factor
42 42
43 self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item()
44
43 self.temp_token_embedding = nn.Embedding( 45 self.temp_token_embedding = nn.Embedding(
44 self.token_embedding.num_embeddings, 46 self.token_embedding.num_embeddings,
45 self.token_embedding.embedding_dim, 47 self.token_embedding.embedding_dim,
@@ -99,7 +101,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
99 101
100 return embeds 102 return embeds
101 103
102 def normalize(self, target: float = 0.4, lambda_: float = 1.0): 104 def normalize(self, target: Optional[float] = None, lambda_: float = 1.0):
105 if target is None:
106 target = self.decay_target
103 w = self.temp_token_embedding.weight 107 w = self.temp_token_embedding.weight
104 pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) 108 pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
105 w[self.temp_token_ids] = F.normalize( 109 w[self.temp_token_ids] = F.normalize(
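The embedding decay target is no longer a hard-coded 0.4: it now defaults to the median L2 norm of the pretrained token embeddings, captured once as decay_target when the embeddings are patched. The body of normalize() is only partially visible in this hunk, so the following is a hedged sketch of the norm-decay idea rather than the exact implementation; the function name and the lerp step are illustrative:

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def decay_embedding_norms(weight: torch.Tensor, ids: torch.Tensor,
                              target: float, lambda_: float = 1.0) -> None:
        """Pull the norm of each trained embedding toward `target`.

        Sketch only: ManagedCLIPTextEmbeddings.normalize() is what actually
        runs, and its full body is not shown in this diff.
        """
        w = weight[ids]
        rescaled = F.normalize(w, dim=-1) * target       # same direction, norm == target
        weight[ids] = torch.lerp(w, rescaled, lambda_)   # lambda_ = 1.0 snaps straight to target

    # With this commit, `target` falls back to
    # token_embedding.weight.norm(dim=-1, keepdim=True).median() (decay_target)
    # whenever normalize() is called without an explicit value.
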
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
deleted file mode 100644
index a7380be..0000000
--- a/models/clip/prompt.py
+++ /dev/null
@@ -1,38 +0,0 @@
1from typing import Union, Optional
2
3import torch
4
5from transformers import CLIPTokenizer, CLIPTextModel
6
7
8class PromptProcessor():
9 def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
10 self.tokenizer = tokenizer
11 self.text_encoder = text_encoder
12
13 def get_input_ids(self, prompt: Union[str, list[str]]):
14 return self.tokenizer(
15 prompt,
16 padding="do_not_pad",
17 ).input_ids
18
19 def unify_input_ids(self, input_ids: list[list[int]]):
20 return self.tokenizer.pad(
21 {"input_ids": input_ids},
22 padding=True,
23 pad_to_multiple_of=self.tokenizer.model_max_length,
24 return_tensors="pt"
25 )
26
27 def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
28 prompts = input_ids.shape[0]
29
30 input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
31 if position_ids is not None:
32 position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
33 if attention_mask is not None:
34 attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
35
36 text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
37 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
38 return text_embeddings
diff --git a/models/clip/util.py b/models/clip/util.py
new file mode 100644
index 0000000..8de8c19
--- /dev/null
+++ b/models/clip/util.py
@@ -0,0 +1,34 @@
1from typing import Optional
2
3import torch
4
5from transformers import CLIPTokenizer, CLIPTextModel
6
7
8def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
9 return tokenizer.pad(
10 {"input_ids": input_ids},
11 padding=True,
12 pad_to_multiple_of=tokenizer.model_max_length,
13 return_tensors="pt"
14 )
15
16
17def get_extended_embeddings(
18 text_encoder: CLIPTextModel,
19 input_ids: torch.LongTensor,
20 position_ids: Optional[torch.LongTensor] = None,
21 attention_mask=None
22):
23 model_max_length = text_encoder.config.max_position_embeddings
24 prompts = input_ids.shape[0]
25
26 input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device)
27 if position_ids is not None:
28 position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device)
29 if attention_mask is not None:
30 attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device)
31
32 text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
33 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
34 return text_embeddings
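The new module replaces the PromptProcessor methods with two free functions: unify_input_ids() pads a batch of variable-length token lists to a multiple of the tokenizer's model_max_length, and get_extended_embeddings() encodes each 77-token chunk separately before stitching the chunks back together per prompt, so prompts longer than the CLIP context window still yield one embedding sequence each. A minimal usage sketch; the model id and prompts are placeholders:

    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    from models.clip.util import unify_input_ids, get_extended_embeddings

    pretrained = "openai/clip-vit-large-patch14"   # placeholder model id
    tokenizer = CLIPTokenizer.from_pretrained(pretrained)
    text_encoder = CLIPTextModel.from_pretrained(pretrained)

    # Tokenize without padding, exactly as the dataset and the pipeline now do.
    input_ids = tokenizer(
        ["a photo of a cat", "a very long prompt " * 20],
        padding="do_not_pad",
    ).input_ids

    # Pad every prompt to a multiple of model_max_length (77 for CLIP); long
    # prompts become several 77-token chunks instead of being truncated.
    inputs = unify_input_ids(tokenizer, input_ids)

    with torch.no_grad():
        # Chunks are encoded independently and re-joined per prompt, yielding a
        # tensor of shape (num_prompts, num_chunks * 77, hidden_size).
        embeddings = get_extended_embeddings(
            text_encoder, inputs.input_ids, attention_mask=inputs.attention_mask
        )
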
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 6bc40e9..a5cfc60 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -22,7 +22,7 @@ from diffusers import (
22from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 22from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
23from diffusers.utils import logging, randn_tensor 23from diffusers.utils import logging, randn_tensor
24from transformers import CLIPTextModel, CLIPTokenizer 24from transformers import CLIPTextModel, CLIPTokenizer
25from models.clip.prompt import PromptProcessor 25from models.clip.util import unify_input_ids, get_extended_embeddings
26 26
27logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28 28
@@ -70,8 +70,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
70 new_config["steps_offset"] = 1 70 new_config["steps_offset"] = 1
71 scheduler._internal_dict = FrozenDict(new_config) 71 scheduler._internal_dict = FrozenDict(new_config)
72 72
73 self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
74
75 self.register_modules( 73 self.register_modules(
76 vae=vae, 74 vae=vae,
77 text_encoder=text_encoder, 75 text_encoder=text_encoder,
@@ -213,16 +211,22 @@ class VlpnStableDiffusion(DiffusionPipeline):
213 do_classifier_free_guidance: bool, 211 do_classifier_free_guidance: bool,
214 device 212 device
215 ): 213 ):
216 text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt 214 if isinstance(prompt[0], str):
215 text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids
216 else:
217 text_input_ids = prompt
218
217 text_input_ids *= num_images_per_prompt 219 text_input_ids *= num_images_per_prompt
218 220
219 if do_classifier_free_guidance: 221 if do_classifier_free_guidance:
220 unconditional_input_ids = self.prompt_processor.get_input_ids( 222 if isinstance(prompt[0], str):
221 negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt 223 unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids
224 else:
225 unconditional_input_ids = negative_prompt
222 unconditional_input_ids *= num_images_per_prompt 226 unconditional_input_ids *= num_images_per_prompt
223 text_input_ids = unconditional_input_ids + text_input_ids 227 text_input_ids = unconditional_input_ids + text_input_ids
224 228
225 text_inputs = self.prompt_processor.unify_input_ids(text_input_ids) 229 text_inputs = unify_input_ids(self.tokenizer, text_input_ids)
226 text_input_ids = text_inputs.input_ids 230 text_input_ids = text_inputs.input_ids
227 231
228 if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 232 if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
@@ -230,7 +234,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
230 else: 234 else:
231 attention_mask = None 235 attention_mask = None
232 236
233 text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask) 237 text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
234 238
235 return text_embeddings 239 return text_embeddings
236 240
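With the PromptProcessor gone, encode_prompt() calls the tokenizer and the new helpers directly. How the id lists are assembled for classifier-free guidance is easiest to see with concrete, made-up values:

    # Illustrative values only; real ids come from self.tokenizer(...).input_ids.
    text_input_ids = [[49406, 320, 1125, 539, 320, 2368, 49407]]   # conditional prompt
    unconditional_input_ids = [[49406, 49407]]                     # empty negative prompt

    num_images_per_prompt = 2
    text_input_ids *= num_images_per_prompt            # list repetition: two copies per prompt
    unconditional_input_ids *= num_images_per_prompt

    # Unconditional ids first, conditional ids second, so the encoded batch lines
    # up with the usual (uncond, cond) split in the guidance step.
    text_input_ids = unconditional_input_ids + text_input_ids

    # unify_input_ids(self.tokenizer, text_input_ids) then pads all four sequences
    # to a multiple of 77 and get_extended_embeddings() encodes them.
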
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0fe590f..fbbe6c2 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -27,7 +27,6 @@ from training.common import loss_step, generate_class_images, get_scheduler
27from training.lr import LRFinder 27from training.lr import LRFinder
28from training.util import AverageMeter, CheckpointerBase, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args
29from models.clip.embeddings import patch_managed_embeddings 29from models.clip.embeddings import patch_managed_embeddings
30from models.clip.prompt import PromptProcessor
31from models.clip.tokenizer import MultiCLIPTokenizer 30from models.clip.tokenizer import MultiCLIPTokenizer
32 31
33logger = get_logger(__name__) 32logger = get_logger(__name__)
@@ -690,8 +689,6 @@ def main():
690 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 689 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
691 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) 690 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
692 691
693 prompt_processor = PromptProcessor(tokenizer, text_encoder)
694
695 if args.scale_lr: 692 if args.scale_lr:
696 args.learning_rate = ( 693 args.learning_rate = (
697 args.learning_rate * args.gradient_accumulation_steps * 694 args.learning_rate * args.gradient_accumulation_steps *
@@ -751,7 +748,7 @@ def main():
751 datamodule = VlpnDataModule( 748 datamodule = VlpnDataModule(
752 data_file=args.train_data_file, 749 data_file=args.train_data_file,
753 batch_size=args.train_batch_size, 750 batch_size=args.train_batch_size,
754 prompt_processor=prompt_processor, 751 tokenizer=tokenizer,
755 class_subdir=args.class_image_dir, 752 class_subdir=args.class_image_dir,
756 num_class_images=args.num_class_images, 753 num_class_images=args.num_class_images,
757 size=args.resolution, 754 size=args.resolution,
@@ -876,7 +873,7 @@ def main():
876 vae, 873 vae,
877 noise_scheduler, 874 noise_scheduler,
878 unet, 875 unet,
879 prompt_processor, 876 text_encoder,
880 args.num_class_images, 877 args.num_class_images,
881 args.prior_loss_weight, 878 args.prior_loss_weight,
882 args.seed, 879 args.seed,
diff --git a/train_ti.py b/train_ti.py
index e18ee38..8c86586 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,11 +21,10 @@ from slugify import slugify
21from util import load_config, load_embeddings_from_dir 21from util import load_config, load_embeddings_from_dir
22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
23from data.csv import VlpnDataModule, VlpnDataItem 23from data.csv import VlpnDataModule, VlpnDataItem
24from training.common import loss_step, generate_class_images, get_scheduler 24from training.common import loss_step, train_loop, generate_class_images, get_scheduler
25from training.lr import LRFinder 25from training.lr import LRFinder
26from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args 26from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
27from models.clip.embeddings import patch_managed_embeddings 27from models.clip.embeddings import patch_managed_embeddings
28from models.clip.prompt import PromptProcessor
29from models.clip.tokenizer import MultiCLIPTokenizer 28from models.clip.tokenizer import MultiCLIPTokenizer
30 29
31logger = get_logger(__name__) 30logger = get_logger(__name__)
@@ -198,12 +197,6 @@ def parse_args():
198 default=100 197 default=100
199 ) 198 )
200 parser.add_argument( 199 parser.add_argument(
201 "--max_train_steps",
202 type=int,
203 default=None,
204 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
205 )
206 parser.add_argument(
207 "--gradient_accumulation_steps", 200 "--gradient_accumulation_steps",
208 type=int, 201 type=int,
209 default=1, 202 default=1,
@@ -409,7 +402,7 @@ def parse_args():
409 ) 402 )
410 parser.add_argument( 403 parser.add_argument(
411 "--decay_target", 404 "--decay_target",
412 default=0.4, 405 default=None,
413 type=float, 406 type=float,
414 help="Embedding decay target." 407 help="Embedding decay target."
415 ) 408 )
@@ -668,8 +661,6 @@ def main():
668 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 661 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
669 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) 662 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
670 663
671 prompt_processor = PromptProcessor(tokenizer, text_encoder)
672
673 if args.scale_lr: 664 if args.scale_lr:
674 args.learning_rate = ( 665 args.learning_rate = (
675 args.learning_rate * args.gradient_accumulation_steps * 666 args.learning_rate * args.gradient_accumulation_steps *
@@ -722,7 +713,7 @@ def main():
722 datamodule = VlpnDataModule( 713 datamodule = VlpnDataModule(
723 data_file=args.train_data_file, 714 data_file=args.train_data_file,
724 batch_size=args.train_batch_size, 715 batch_size=args.train_batch_size,
725 prompt_processor=prompt_processor, 716 tokenizer=tokenizer,
726 class_subdir=args.class_image_dir, 717 class_subdir=args.class_image_dir,
727 num_class_images=args.num_class_images, 718 num_class_images=args.num_class_images,
728 size=args.resolution, 719 size=args.resolution,
@@ -759,13 +750,7 @@ def main():
759 args.sample_steps 750 args.sample_steps
760 ) 751 )
761 752
762 # Scheduler and math around the number of training steps.
763 overrode_max_train_steps = False
764 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 753 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
765 if args.max_train_steps is None:
766 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
767 overrode_max_train_steps = True
768 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
769 754
770 if args.find_lr: 755 if args.find_lr:
771 lr_scheduler = None 756 lr_scheduler = None
@@ -781,7 +766,7 @@ def main():
781 annealing_exp=args.lr_annealing_exp, 766 annealing_exp=args.lr_annealing_exp,
782 cycles=args.lr_cycles, 767 cycles=args.lr_cycles,
783 warmup_epochs=args.lr_warmup_epochs, 768 warmup_epochs=args.lr_warmup_epochs,
784 max_train_steps=args.max_train_steps, 769 num_train_epochs=args.num_train_epochs,
785 num_update_steps_per_epoch=num_update_steps_per_epoch, 770 num_update_steps_per_epoch=num_update_steps_per_epoch,
786 gradient_accumulation_steps=args.gradient_accumulation_steps 771 gradient_accumulation_steps=args.gradient_accumulation_steps
787 ) 772 )
@@ -805,15 +790,6 @@ def main():
805 else: 790 else:
806 unet.eval() 791 unet.eval()
807 792
808 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
809 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
810 if overrode_max_train_steps:
811 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
812
813 num_val_steps_per_epoch = len(val_dataloader)
814 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
815 val_steps = num_val_steps_per_epoch * num_epochs
816
817 @contextmanager 793 @contextmanager
818 def on_train(): 794 def on_train():
819 try: 795 try:
@@ -842,19 +818,44 @@ def main():
842 min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) 818 min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
843 ) 819 )
844 820
821 if args.use_ema:
822 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
823
824 def on_log():
825 if args.use_ema:
826 return {"ema_decay": ema_embeddings.decay}
827 return {}
828
845 loop = partial( 829 loop = partial(
846 loss_step, 830 loss_step,
847 vae, 831 vae,
848 noise_scheduler, 832 noise_scheduler,
849 unet, 833 unet,
850 prompt_processor, 834 text_encoder,
851 args.num_class_images, 835 args.num_class_images,
852 args.prior_loss_weight, 836 args.prior_loss_weight,
853 args.seed, 837 args.seed,
854 ) 838 )
855 839
856 # We need to initialize the trackers we use, and also store our configuration. 840 checkpointer = Checkpointer(
857 # The trackers initializes automatically on the main process. 841 weight_dtype=weight_dtype,
842 datamodule=datamodule,
843 accelerator=accelerator,
844 vae=vae,
845 unet=unet,
846 tokenizer=tokenizer,
847 text_encoder=text_encoder,
848 ema_embeddings=ema_embeddings,
849 scheduler=checkpoint_scheduler,
850 placeholder_token=args.placeholder_token,
851 new_ids=new_ids,
852 output_dir=basepath,
853 sample_image_size=args.sample_image_size,
854 sample_batch_size=args.sample_batch_size,
855 sample_batches=args.sample_batches,
856 seed=args.seed
857 )
858
858 if accelerator.is_main_process: 859 if accelerator.is_main_process:
859 config = vars(args).copy() 860 config = vars(args).copy()
860 config["initializer_token"] = " ".join(config["initializer_token"]) 861 config["initializer_token"] = " ".join(config["initializer_token"])
@@ -882,190 +883,27 @@ def main():
882 883
883 plt.savefig(basepath.joinpath("lr.png"), dpi=300) 884 plt.savefig(basepath.joinpath("lr.png"), dpi=300)
884 plt.close() 885 plt.close()
885 886 else:
886 quit() 887 train_loop(
887 888 accelerator=accelerator,
888 # Train! 889 optimizer=optimizer,
889 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 890 lr_scheduler=lr_scheduler,
890 891 model=text_encoder,
891 logger.info("***** Running training *****") 892 checkpointer=checkpointer,
892 logger.info(f" Num Epochs = {num_epochs}") 893 train_dataloader=train_dataloader,
893 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 894 val_dataloader=val_dataloader,
894 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 895 loss_step=loop,
895 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 896 sample_frequency=args.sample_frequency,
896 logger.info(f" Total optimization steps = {args.max_train_steps}") 897 sample_steps=args.sample_steps,
897 # Only show the progress bar once on each machine. 898 checkpoint_frequency=args.checkpoint_frequency,
898 899 global_step_offset=global_step_offset,
899 global_step = 0 900 gradient_accumulation_steps=args.gradient_accumulation_steps,
900 901 num_epochs=args.num_train_epochs,
901 avg_loss = AverageMeter() 902 on_log=on_log,
902 avg_acc = AverageMeter() 903 on_train=on_train,
903 904 on_after_optimize=on_after_optimize,
904 avg_loss_val = AverageMeter() 905 on_eval=on_eval
905 avg_acc_val = AverageMeter() 906 )
906
907 max_acc_val = 0.0
908
909 checkpointer = Checkpointer(
910 weight_dtype=weight_dtype,
911 datamodule=datamodule,
912 accelerator=accelerator,
913 vae=vae,
914 unet=unet,
915 tokenizer=tokenizer,
916 text_encoder=text_encoder,
917 ema_embeddings=ema_embeddings,
918 scheduler=checkpoint_scheduler,
919 placeholder_token=args.placeholder_token,
920 new_ids=new_ids,
921 output_dir=basepath,
922 sample_image_size=args.sample_image_size,
923 sample_batch_size=args.sample_batch_size,
924 sample_batches=args.sample_batches,
925 seed=args.seed
926 )
927
928 local_progress_bar = tqdm(
929 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
930 disable=not accelerator.is_local_main_process,
931 dynamic_ncols=True
932 )
933 local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
934
935 global_progress_bar = tqdm(
936 range(args.max_train_steps + val_steps),
937 disable=not accelerator.is_local_main_process,
938 dynamic_ncols=True
939 )
940 global_progress_bar.set_description("Total progress")
941
942 try:
943 for epoch in range(num_epochs):
944 if accelerator.is_main_process:
945 if epoch % args.sample_frequency == 0:
946 checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
947
948 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
949 local_progress_bar.reset()
950
951 text_encoder.train()
952
953 with on_train():
954 for step, batch in enumerate(train_dataloader):
955 with accelerator.accumulate(text_encoder):
956 loss, acc, bsz = loop(step, batch)
957
958 accelerator.backward(loss)
959
960 optimizer.step()
961 lr_scheduler.step()
962 optimizer.zero_grad(set_to_none=True)
963
964 avg_loss.update(loss.detach_(), bsz)
965 avg_acc.update(acc.detach_(), bsz)
966
967 # Checks if the accelerator has performed an optimization step behind the scenes
968 if accelerator.sync_gradients:
969 on_after_optimize(lr_scheduler.get_last_lr()[0])
970
971 if args.use_ema:
972 ema_embeddings.step(
973 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
974
975 local_progress_bar.update(1)
976 global_progress_bar.update(1)
977
978 global_step += 1
979
980 logs = {
981 "train/loss": avg_loss.avg.item(),
982 "train/acc": avg_acc.avg.item(),
983 "train/cur_loss": loss.item(),
984 "train/cur_acc": acc.item(),
985 "lr": lr_scheduler.get_last_lr()[0],
986 }
987 if args.use_ema:
988 logs["ema_decay"] = ema_embeddings.decay
989
990 accelerator.log(logs, step=global_step)
991
992 local_progress_bar.set_postfix(**logs)
993
994 if global_step >= args.max_train_steps:
995 break
996
997 accelerator.wait_for_everyone()
998
999 text_encoder.eval()
1000
1001 cur_loss_val = AverageMeter()
1002 cur_acc_val = AverageMeter()
1003
1004 with torch.inference_mode():
1005 with on_eval():
1006 for step, batch in enumerate(val_dataloader):
1007 loss, acc, bsz = loop(step, batch, True)
1008
1009 loss = loss.detach_()
1010 acc = acc.detach_()
1011
1012 cur_loss_val.update(loss, bsz)
1013 cur_acc_val.update(acc, bsz)
1014
1015 avg_loss_val.update(loss, bsz)
1016 avg_acc_val.update(acc, bsz)
1017
1018 local_progress_bar.update(1)
1019 global_progress_bar.update(1)
1020
1021 logs = {
1022 "val/loss": avg_loss_val.avg.item(),
1023 "val/acc": avg_acc_val.avg.item(),
1024 "val/cur_loss": loss.item(),
1025 "val/cur_acc": acc.item(),
1026 }
1027 local_progress_bar.set_postfix(**logs)
1028
1029 logs["val/cur_loss"] = cur_loss_val.avg.item()
1030 logs["val/cur_acc"] = cur_acc_val.avg.item()
1031
1032 accelerator.log(logs, step=global_step)
1033
1034 local_progress_bar.clear()
1035 global_progress_bar.clear()
1036
1037 if accelerator.is_main_process:
1038 if avg_acc_val.avg.item() > max_acc_val:
1039 accelerator.print(
1040 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
1041 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
1042 max_acc_val = avg_acc_val.avg.item()
1043
1044 if (epoch + 1) % args.checkpoint_frequency == 0:
1045 checkpointer.checkpoint(global_step + global_step_offset, "training")
1046 save_args(basepath, args, {
1047 "global_step": global_step + global_step_offset
1048 })
1049
1050 # Create the pipeline using using the trained modules and save it.
1051 if accelerator.is_main_process:
1052 print("Finished! Saving final checkpoint and resume state.")
1053 checkpointer.checkpoint(global_step + global_step_offset, "end")
1054 checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
1055 save_args(basepath, args, {
1056 "global_step": global_step + global_step_offset
1057 })
1058 accelerator.end_training()
1059
1060 except KeyboardInterrupt:
1061 if accelerator.is_main_process:
1062 print("Interrupted, saving checkpoint and resume state...")
1063 checkpointer.checkpoint(global_step + global_step_offset, "end")
1064 save_args(basepath, args, {
1065 "global_step": global_step + global_step_offset
1066 })
1067 accelerator.end_training()
1068 quit()
1069 907
1070 908
1071if __name__ == "__main__": 909if __name__ == "__main__":
diff --git a/training/common.py b/training/common.py
index 90cf910..842ac07 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,14 +1,30 @@
1import math 1import math
2from contextlib import _GeneratorContextManager, nullcontext
3from typing import Callable, Any, Tuple, Union
2 4
3import torch 5import torch
4import torch.nn.functional as F 6import torch.nn.functional as F
7from torch.utils.data import DataLoader
5 8
9from accelerate import Accelerator
10from transformers import CLIPTokenizer, CLIPTextModel
6from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 11from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
7from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 12from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
8 13
9from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 14from tqdm.auto import tqdm
10 15
16from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
17from models.clip.util import get_extended_embeddings
11from training.optimization import get_one_cycle_schedule 18from training.optimization import get_one_cycle_schedule
19from training.util import AverageMeter, CheckpointerBase
20
21
22def noop(*args, **kwards):
23 pass
24
25
26def noop_on_log():
27 return {}
12 28
13 29
14def get_scheduler( 30def get_scheduler(
@@ -22,10 +38,11 @@ def get_scheduler(
22 cycles: int, 38 cycles: int,
23 warmup_epochs: int, 39 warmup_epochs: int,
24 optimizer: torch.optim.Optimizer, 40 optimizer: torch.optim.Optimizer,
25 max_train_steps: int, 41 num_train_epochs: int,
26 num_update_steps_per_epoch: int, 42 num_update_steps_per_epoch: int,
27 gradient_accumulation_steps: int, 43 gradient_accumulation_steps: int,
28): 44):
45 num_train_steps = num_train_epochs * num_update_steps_per_epoch
29 warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps 46 warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
30 47
31 if id == "one_cycle": 48 if id == "one_cycle":
@@ -33,7 +50,7 @@ def get_scheduler(
33 50
34 lr_scheduler = get_one_cycle_schedule( 51 lr_scheduler = get_one_cycle_schedule(
35 optimizer=optimizer, 52 optimizer=optimizer,
36 num_training_steps=max_train_steps * gradient_accumulation_steps, 53 num_training_steps=num_train_steps * gradient_accumulation_steps,
37 warmup=warmup_func, 54 warmup=warmup_func,
38 annealing=annealing_func, 55 annealing=annealing_func,
39 warmup_exp=warmup_exp, 56 warmup_exp=warmup_exp,
@@ -42,12 +59,12 @@ def get_scheduler(
42 ) 59 )
43 elif id == "cosine_with_restarts": 60 elif id == "cosine_with_restarts":
44 cycles = cycles if cycles is not None else math.ceil( 61 cycles = cycles if cycles is not None else math.ceil(
45 math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch))) 62 math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
46 63
47 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 64 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
48 optimizer=optimizer, 65 optimizer=optimizer,
49 num_warmup_steps=warmup_steps, 66 num_warmup_steps=warmup_steps,
50 num_training_steps=max_train_steps * gradient_accumulation_steps, 67 num_training_steps=num_train_steps * gradient_accumulation_steps,
51 num_cycles=cycles, 68 num_cycles=cycles,
52 ) 69 )
53 else: 70 else:
@@ -55,7 +72,7 @@ def get_scheduler(
55 id, 72 id,
56 optimizer=optimizer, 73 optimizer=optimizer,
57 num_warmup_steps=warmup_steps, 74 num_warmup_steps=warmup_steps,
58 num_training_steps=max_train_steps * gradient_accumulation_steps, 75 num_training_steps=num_train_steps * gradient_accumulation_steps,
59 ) 76 )
60 77
61 return lr_scheduler 78 return lr_scheduler
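get_scheduler() is now sized from epochs instead of an explicit max_train_steps, deriving the total internally. Spelled out with example figures (numbers are not from the diff):

    import math

    len_train_dataloader = 1000
    gradient_accumulation_steps = 4
    num_train_epochs = 100
    warmup_epochs = 5

    num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 250
    num_train_steps = num_train_epochs * num_update_steps_per_epoch                             # 25_000
    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps     # 5_000

    # The wrapped schedulers receive num_train_steps * gradient_accumulation_steps
    # (100_000 here), since lr_scheduler.step() runs once per micro-batch in train_loop().
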
@@ -117,12 +134,12 @@ def loss_step(
117 vae: AutoencoderKL, 134 vae: AutoencoderKL,
118 noise_scheduler: DDPMScheduler, 135 noise_scheduler: DDPMScheduler,
119 unet: UNet2DConditionModel, 136 unet: UNet2DConditionModel,
120 prompt_processor, 137 text_encoder: CLIPTextModel,
121 num_class_images: int, 138 num_class_images: int,
122 prior_loss_weight: float, 139 prior_loss_weight: float,
123 seed: int, 140 seed: int,
124 step: int, 141 step: int,
125 batch, 142 batch: dict[str, Any],
126 eval: bool = False 143 eval: bool = False
127): 144):
128 # Convert images to latent space 145 # Convert images to latent space
@@ -149,7 +166,8 @@ def loss_step(
149 noisy_latents = noisy_latents.to(dtype=unet.dtype) 166 noisy_latents = noisy_latents.to(dtype=unet.dtype)
150 167
151 # Get the text embedding for conditioning 168 # Get the text embedding for conditioning
152 encoder_hidden_states = prompt_processor.get_embeddings( 169 encoder_hidden_states = get_extended_embeddings(
170 text_encoder,
153 batch["input_ids"], 171 batch["input_ids"],
154 batch["attention_mask"] 172 batch["attention_mask"]
155 ) 173 )
@@ -185,3 +203,172 @@ def loss_step(
185 acc = (model_pred == target).float().mean() 203 acc = (model_pred == target).float().mean()
186 204
187 return loss, acc, bsz 205 return loss, acc, bsz
206
207
208def train_loop(
209 accelerator: Accelerator,
210 optimizer: torch.optim.Optimizer,
211 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
212 model: torch.nn.Module,
213 checkpointer: CheckpointerBase,
214 train_dataloader: DataLoader,
215 val_dataloader: DataLoader,
216 loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
217 sample_frequency: int = 10,
218 sample_steps: int = 20,
219 checkpoint_frequency: int = 50,
220 global_step_offset: int = 0,
221 gradient_accumulation_steps: int = 1,
222 num_epochs: int = 100,
223 on_log: Callable[[], dict[str, Any]] = noop_on_log,
224 on_train: Callable[[], _GeneratorContextManager] = nullcontext,
225 on_before_optimize: Callable[[], None] = noop,
226 on_after_optimize: Callable[[float], None] = noop,
227 on_eval: Callable[[], _GeneratorContextManager] = nullcontext
228):
229 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
230 num_train_steps = num_epochs * num_update_steps_per_epoch
231
232 num_val_steps_per_epoch = len(val_dataloader)
233 num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
234 num_val_steps = num_val_steps_per_epoch * num_epochs
235
236 global_step = 0
237
238 avg_loss = AverageMeter()
239 avg_acc = AverageMeter()
240
241 avg_loss_val = AverageMeter()
242 avg_acc_val = AverageMeter()
243
244 max_acc_val = 0.0
245
246 local_progress_bar = tqdm(
247 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
248 disable=not accelerator.is_local_main_process,
249 dynamic_ncols=True
250 )
251 local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
252
253 global_progress_bar = tqdm(
254 range(num_train_steps + num_val_steps),
255 disable=not accelerator.is_local_main_process,
256 dynamic_ncols=True
257 )
258 global_progress_bar.set_description("Total progress")
259
260 try:
261 for epoch in range(num_epochs):
262 if accelerator.is_main_process:
263 if epoch % sample_frequency == 0:
264 checkpointer.save_samples(global_step + global_step_offset, sample_steps)
265
266 if epoch % checkpoint_frequency == 0 and epoch != 0:
267 checkpointer.checkpoint(global_step + global_step_offset, "training")
268
269 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
270 local_progress_bar.reset()
271
272 model.train()
273
274 with on_train():
275 for step, batch in enumerate(train_dataloader):
276 with accelerator.accumulate(model):
277 loss, acc, bsz = loss_step(step, batch)
278
279 accelerator.backward(loss)
280
281 on_before_optimize()
282
283 optimizer.step()
284 lr_scheduler.step()
285 optimizer.zero_grad(set_to_none=True)
286
287 avg_loss.update(loss.detach_(), bsz)
288 avg_acc.update(acc.detach_(), bsz)
289
290 # Checks if the accelerator has performed an optimization step behind the scenes
291 if accelerator.sync_gradients:
292 on_after_optimize(lr_scheduler.get_last_lr()[0])
293
294 local_progress_bar.update(1)
295 global_progress_bar.update(1)
296
297 global_step += 1
298
299 logs = {
300 "train/loss": avg_loss.avg.item(),
301 "train/acc": avg_acc.avg.item(),
302 "train/cur_loss": loss.item(),
303 "train/cur_acc": acc.item(),
304 "lr": lr_scheduler.get_last_lr()[0],
305 }
306 logs.update(on_log())
307
308 accelerator.log(logs, step=global_step)
309
310 local_progress_bar.set_postfix(**logs)
311
312 if global_step >= num_train_steps:
313 break
314
315 accelerator.wait_for_everyone()
316
317 model.eval()
318
319 cur_loss_val = AverageMeter()
320 cur_acc_val = AverageMeter()
321
322 with torch.inference_mode():
323 with on_eval():
324 for step, batch in enumerate(val_dataloader):
325 loss, acc, bsz = loss_step(step, batch, True)
326
327 loss = loss.detach_()
328 acc = acc.detach_()
329
330 cur_loss_val.update(loss, bsz)
331 cur_acc_val.update(acc, bsz)
332
333 avg_loss_val.update(loss, bsz)
334 avg_acc_val.update(acc, bsz)
335
336 local_progress_bar.update(1)
337 global_progress_bar.update(1)
338
339 logs = {
340 "val/loss": avg_loss_val.avg.item(),
341 "val/acc": avg_acc_val.avg.item(),
342 "val/cur_loss": loss.item(),
343 "val/cur_acc": acc.item(),
344 }
345 local_progress_bar.set_postfix(**logs)
346
347 logs["val/cur_loss"] = cur_loss_val.avg.item()
348 logs["val/cur_acc"] = cur_acc_val.avg.item()
349
350 accelerator.log(logs, step=global_step)
351
352 local_progress_bar.clear()
353 global_progress_bar.clear()
354
355 if accelerator.is_main_process:
356 if avg_acc_val.avg.item() > max_acc_val:
357 accelerator.print(
358 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
359 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
360 max_acc_val = avg_acc_val.avg.item()
361
362 # Create the pipeline using using the trained modules and save it.
363 if accelerator.is_main_process:
364 print("Finished!")
365 checkpointer.checkpoint(global_step + global_step_offset, "end")
366 checkpointer.save_samples(global_step + global_step_offset, sample_steps)
367 accelerator.end_training()
368
369 except KeyboardInterrupt:
370 if accelerator.is_main_process:
371 print("Interrupted")
372 checkpointer.checkpoint(global_step + global_step_offset, "end")
373 accelerator.end_training()
374 quit()
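train_loop() now owns the boilerplate that train_ti.py carried before: progress bars, gradient accumulation, validation, milestone/periodic checkpoints and KeyboardInterrupt handling. The callback contract, as far as the signature shows: on_train and on_eval are context managers wrapping the training and validation phases, on_before_optimize/on_after_optimize bracket the optimizer step (the latter receives the current learning rate), and on_log returns extra key/value pairs merged into the logged metrics. A hedged invocation sketch; the wrapper function and the callback bodies are illustrative:

    from contextlib import contextmanager
    from functools import partial

    from training.common import loss_step, train_loop

    def run_training(accelerator, optimizer, lr_scheduler, text_encoder, checkpointer,
                     train_dataloader, val_dataloader, vae, noise_scheduler, unet,
                     num_class_images, prior_loss_weight, seed, num_train_epochs,
                     gradient_accumulation_steps):
        """Hypothetical wrapper showing how train_ti.py wires the pieces together."""

        @contextmanager
        def on_train():
            try:
                yield          # real body not fully shown in this diff; no-op here
            finally:
                pass

        def on_log():
            return {}          # train_ti.py returns {"ema_decay": ...} when EMA is enabled

        def on_after_optimize(lr: float):
            pass               # train_ti.py applies embedding norm decay based on lr

        loop = partial(loss_step, vae, noise_scheduler, unet, text_encoder,
                       num_class_images, prior_loss_weight, seed)

        train_loop(
            accelerator=accelerator,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            model=text_encoder,
            checkpointer=checkpointer,          # any CheckpointerBase subclass
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            loss_step=loop,
            num_epochs=num_train_epochs,
            gradient_accumulation_steps=gradient_accumulation_steps,
            on_log=on_log,
            on_train=on_train,
            on_after_optimize=on_after_optimize,
        )
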
diff --git a/training/util.py b/training/util.py
index 60d64f0..0ec2032 100644
--- a/training/util.py
+++ b/training/util.py
@@ -55,8 +55,19 @@ class CheckpointerBase:
55 self.sample_batches = sample_batches 55 self.sample_batches = sample_batches
56 self.sample_batch_size = sample_batch_size 56 self.sample_batch_size = sample_batch_size
57 57
58 @torch.no_grad()
59 def checkpoint(self, step: int, postfix: str):
60 pass
61
58 @torch.inference_mode() 62 @torch.inference_mode()
59 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 63 def save_samples(
64 self,
65 pipeline,
66 step: int,
67 num_inference_steps: int,
68 guidance_scale: float = 7.5,
69 eta: float = 0.0
70 ):
60 samples_path = Path(self.output_dir).joinpath("samples") 71 samples_path = Path(self.output_dir).joinpath("samples")
61 72
62 train_data = self.datamodule.train_dataloader 73 train_data = self.datamodule.train_dataloader