 data/csv.py                                          |  36
 models/clip/embeddings.py                            |   6
 models/clip/prompt.py                                |  38
 models/clip/util.py                                  |  34
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  20
 train_dreambooth.py                                  |   7
 train_ti.py                                          | 268
 training/common.py                                   | 205
 training/util.py                                     |  13
 9 files changed, 334 insertions(+), 293 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index f5fc8e6..a3fef30 100644
--- a/data/csv.py
+++ b/data/csv.py
| @@ -9,9 +9,10 @@ from PIL import Image | |||
| 9 | 9 | ||
| 10 | from torch.utils.data import IterableDataset, DataLoader, random_split | 10 | from torch.utils.data import IterableDataset, DataLoader, random_split |
| 11 | from torchvision import transforms | 11 | from torchvision import transforms |
| 12 | from transformers import CLIPTokenizer | ||
| 12 | 13 | ||
| 13 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 14 | from data.keywords import prompt_to_keywords, keywords_to_prompt |
| 14 | from models.clip.prompt import PromptProcessor | 15 | from models.clip.util import unify_input_ids |
| 15 | 16 | ||
| 16 | 17 | ||
| 17 | image_cache: dict[str, Image.Image] = {} | 18 | image_cache: dict[str, Image.Image] = {} |
| @@ -102,7 +103,7 @@ def generate_buckets( | |||
| 102 | def collate_fn( | 103 | def collate_fn( |
| 103 | num_class_images: int, | 104 | num_class_images: int, |
| 104 | weight_dtype: torch.dtype, | 105 | weight_dtype: torch.dtype, |
| 105 | prompt_processor: PromptProcessor, | 106 | tokenizer: CLIPTokenizer, |
| 106 | examples | 107 | examples |
| 107 | ): | 108 | ): |
| 108 | prompt_ids = [example["prompt_ids"] for example in examples] | 109 | prompt_ids = [example["prompt_ids"] for example in examples] |
| @@ -119,9 +120,9 @@ def collate_fn( | |||
| 119 | pixel_values = torch.stack(pixel_values) | 120 | pixel_values = torch.stack(pixel_values) |
| 120 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 121 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
| 121 | 122 | ||
| 122 | prompts = prompt_processor.unify_input_ids(prompt_ids) | 123 | prompts = unify_input_ids(tokenizer, prompt_ids) |
| 123 | nprompts = prompt_processor.unify_input_ids(nprompt_ids) | 124 | nprompts = unify_input_ids(tokenizer, nprompt_ids) |
| 124 | inputs = prompt_processor.unify_input_ids(input_ids) | 125 | inputs = unify_input_ids(tokenizer, input_ids) |
| 125 | 126 | ||
| 126 | batch = { | 127 | batch = { |
| 127 | "prompt_ids": prompts.input_ids, | 128 | "prompt_ids": prompts.input_ids, |
| @@ -148,7 +149,7 @@ class VlpnDataModule(): | |||
| 148 | self, | 149 | self, |
| 149 | batch_size: int, | 150 | batch_size: int, |
| 150 | data_file: str, | 151 | data_file: str, |
| 151 | prompt_processor: PromptProcessor, | 152 | tokenizer: CLIPTokenizer, |
| 152 | class_subdir: str = "cls", | 153 | class_subdir: str = "cls", |
| 153 | num_class_images: int = 1, | 154 | num_class_images: int = 1, |
| 154 | size: int = 768, | 155 | size: int = 768, |
| @@ -179,7 +180,7 @@ class VlpnDataModule(): | |||
| 179 | self.class_root.mkdir(parents=True, exist_ok=True) | 180 | self.class_root.mkdir(parents=True, exist_ok=True) |
| 180 | self.num_class_images = num_class_images | 181 | self.num_class_images = num_class_images |
| 181 | 182 | ||
| 182 | self.prompt_processor = prompt_processor | 183 | self.tokenizer = tokenizer |
| 183 | self.size = size | 184 | self.size = size |
| 184 | self.num_buckets = num_buckets | 185 | self.num_buckets = num_buckets |
| 185 | self.bucket_step_size = bucket_step_size | 186 | self.bucket_step_size = bucket_step_size |
| @@ -272,7 +273,7 @@ class VlpnDataModule(): | |||
| 272 | self.data_val = self.pad_items(data_val) | 273 | self.data_val = self.pad_items(data_val) |
| 273 | 274 | ||
| 274 | train_dataset = VlpnDataset( | 275 | train_dataset = VlpnDataset( |
| 275 | self.data_train, self.prompt_processor, | 276 | self.data_train, self.tokenizer, |
| 276 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 277 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
| 277 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 278 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
| 278 | batch_size=self.batch_size, generator=generator, | 279 | batch_size=self.batch_size, generator=generator, |
| @@ -281,7 +282,7 @@ class VlpnDataModule(): | |||
| 281 | ) | 282 | ) |
| 282 | 283 | ||
| 283 | val_dataset = VlpnDataset( | 284 | val_dataset = VlpnDataset( |
| 284 | self.data_val, self.prompt_processor, | 285 | self.data_val, self.tokenizer, |
| 285 | num_buckets=self.num_buckets, progressive_buckets=True, | 286 | num_buckets=self.num_buckets, progressive_buckets=True, |
| 286 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 287 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
| 287 | repeat=self.valid_set_repeat, | 288 | repeat=self.valid_set_repeat, |
| @@ -289,7 +290,7 @@ class VlpnDataModule(): | |||
| 289 | size=self.size, interpolation=self.interpolation, | 290 | size=self.size, interpolation=self.interpolation, |
| 290 | ) | 291 | ) |
| 291 | 292 | ||
| 292 | collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor) | 293 | collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer) |
| 293 | 294 | ||
| 294 | self.train_dataloader = DataLoader( | 295 | self.train_dataloader = DataLoader( |
| 295 | train_dataset, | 296 | train_dataset, |
| @@ -306,7 +307,7 @@ class VlpnDataset(IterableDataset): | |||
| 306 | def __init__( | 307 | def __init__( |
| 307 | self, | 308 | self, |
| 308 | items: list[VlpnDataItem], | 309 | items: list[VlpnDataItem], |
| 309 | prompt_processor: PromptProcessor, | 310 | tokenizer: CLIPTokenizer, |
| 310 | num_buckets: int = 1, | 311 | num_buckets: int = 1, |
| 311 | bucket_step_size: int = 64, | 312 | bucket_step_size: int = 64, |
| 312 | bucket_max_pixels: Optional[int] = None, | 313 | bucket_max_pixels: Optional[int] = None, |
| @@ -323,7 +324,7 @@ class VlpnDataset(IterableDataset): | |||
| 323 | self.items = items * repeat | 324 | self.items = items * repeat |
| 324 | self.batch_size = batch_size | 325 | self.batch_size = batch_size |
| 325 | 326 | ||
| 326 | self.prompt_processor = prompt_processor | 327 | self.tokenizer = tokenizer |
| 327 | self.num_class_images = num_class_images | 328 | self.num_class_images = num_class_images |
| 328 | self.size = size | 329 | self.size = size |
| 329 | self.dropout = dropout | 330 | self.dropout = dropout |
| @@ -344,6 +345,9 @@ class VlpnDataset(IterableDataset): | |||
| 344 | 345 | ||
| 345 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 346 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() |
| 346 | 347 | ||
| 348 | def get_input_ids(self, text: str): | ||
| 349 | return self.tokenizer(text, padding="do_not_pad").input_ids | ||
| 350 | |||
| 347 | def __len__(self): | 351 | def __len__(self): |
| 348 | return self.length_ | 352 | return self.length_ |
| 349 | 353 | ||
| @@ -404,16 +408,16 @@ class VlpnDataset(IterableDataset): | |||
| 404 | 408 | ||
| 405 | example = {} | 409 | example = {} |
| 406 | 410 | ||
| 407 | example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) | 411 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) |
| 408 | example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) | 412 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
| 409 | 413 | ||
| 410 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 414 | example["instance_prompt_ids"] = self.get_input_ids( |
| 411 | keywords_to_prompt(item.prompt, self.dropout, True) | 415 | keywords_to_prompt(item.prompt, self.dropout, True) |
| 412 | ) | 416 | ) |
| 413 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 417 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
| 414 | 418 | ||
| 415 | if self.num_class_images != 0: | 419 | if self.num_class_images != 0: |
| 416 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) | 420 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) |
| 417 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 421 | example["class_images"] = image_transforms(get_image(item.class_image_path)) |
| 418 | 422 | ||
| 419 | batch.append(example) | 423 | batch.append(example) |
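Note: a minimal sketch of how the refactored collate path fits together now that the dataset holds a plain CLIPTokenizer instead of a PromptProcessor. The checkpoint name is an assumption; only the tokenizer path is exercised, not the image pipeline.

```python
# Sketch only: prompts are tokenized without padding inside the dataset
# (VlpnDataset.get_input_ids) and padded per batch in collate_fn via unify_input_ids.
from functools import partial

import torch
from transformers import CLIPTokenizer

from data.csv import collate_fn
from models.clip.util import unify_input_ids

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed checkpoint

# What VlpnDataset.get_input_ids does per item: no padding yet.
prompt_ids = tokenizer("a photo of a cat", padding="do_not_pad").input_ids

# What collate_fn does per batch: pad to a multiple of model_max_length (77).
batch_inputs = unify_input_ids(tokenizer, [prompt_ids])
print(batch_inputs.input_ids.shape)  # e.g. torch.Size([1, 77])

# The datamodule binds the tokenizer exactly as in the partial() call above.
collate = partial(collate_fn, 0, torch.float32, tokenizer)  # num_class_images=0 assumed
```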
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9a23a2a..761efbc 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
| @@ -40,6 +40,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 40 | self.position_embedding = embeddings.position_embedding | 40 | self.position_embedding = embeddings.position_embedding |
| 41 | self.initializer_factor = config.initializer_factor | 41 | self.initializer_factor = config.initializer_factor |
| 42 | 42 | ||
| 43 | self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item() | ||
| 44 | |||
| 43 | self.temp_token_embedding = nn.Embedding( | 45 | self.temp_token_embedding = nn.Embedding( |
| 44 | self.token_embedding.num_embeddings, | 46 | self.token_embedding.num_embeddings, |
| 45 | self.token_embedding.embedding_dim, | 47 | self.token_embedding.embedding_dim, |
| @@ -99,7 +101,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 99 | 101 | ||
| 100 | return embeds | 102 | return embeds |
| 101 | 103 | ||
| 102 | def normalize(self, target: float = 0.4, lambda_: float = 1.0): | 104 | def normalize(self, target: Optional[float] = None, lambda_: float = 1.0): |
| 105 | if target is None: | ||
| 106 | target = self.decay_target | ||
| 103 | w = self.temp_token_embedding.weight | 107 | w = self.temp_token_embedding.weight |
| 104 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | 108 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) |
| 105 | w[self.temp_token_ids] = F.normalize( | 109 | w[self.temp_token_ids] = F.normalize( |
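Note: the default decay target is now derived from the pretrained weights (the median L2 norm of all token embeddings) instead of the previous hard-coded 0.4. Below is a self-contained sketch of that computation; the randomly initialized embedding stands in for the pretrained CLIP token_embedding purely so the snippet runs standalone.

```python
# Illustration of the decay_target computation in ManagedCLIPTextEmbeddings.
import torch
import torch.nn as nn

token_embedding = nn.Embedding(49408, 768)  # CLIP ViT-L/14 text vocab / hidden sizes
decay_target = token_embedding.weight.norm(dim=-1, keepdim=True).median().item()

# normalize() falls back to this value when no explicit target is passed.
# (Random init here, so the printed number is not representative of the
# value obtained from the real pretrained weights.)
print(decay_target)
```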
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
deleted file mode 100644
index a7380be..0000000
--- a/models/clip/prompt.py
+++ /dev/null
| @@ -1,38 +0,0 @@ | |||
| 1 | from typing import Union, Optional | ||
| 2 | |||
| 3 | import torch | ||
| 4 | |||
| 5 | from transformers import CLIPTokenizer, CLIPTextModel | ||
| 6 | |||
| 7 | |||
| 8 | class PromptProcessor(): | ||
| 9 | def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel): | ||
| 10 | self.tokenizer = tokenizer | ||
| 11 | self.text_encoder = text_encoder | ||
| 12 | |||
| 13 | def get_input_ids(self, prompt: Union[str, list[str]]): | ||
| 14 | return self.tokenizer( | ||
| 15 | prompt, | ||
| 16 | padding="do_not_pad", | ||
| 17 | ).input_ids | ||
| 18 | |||
| 19 | def unify_input_ids(self, input_ids: list[list[int]]): | ||
| 20 | return self.tokenizer.pad( | ||
| 21 | {"input_ids": input_ids}, | ||
| 22 | padding=True, | ||
| 23 | pad_to_multiple_of=self.tokenizer.model_max_length, | ||
| 24 | return_tensors="pt" | ||
| 25 | ) | ||
| 26 | |||
| 27 | def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None): | ||
| 28 | prompts = input_ids.shape[0] | ||
| 29 | |||
| 30 | input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
| 31 | if position_ids is not None: | ||
| 32 | position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
| 33 | if attention_mask is not None: | ||
| 34 | attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
| 35 | |||
| 36 | text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | ||
| 37 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | ||
| 38 | return text_embeddings | ||
diff --git a/models/clip/util.py b/models/clip/util.py
new file mode 100644
index 0000000..8de8c19
--- /dev/null
+++ b/models/clip/util.py
| @@ -0,0 +1,34 @@ | |||
| 1 | from typing import Optional | ||
| 2 | |||
| 3 | import torch | ||
| 4 | |||
| 5 | from transformers import CLIPTokenizer, CLIPTextModel | ||
| 6 | |||
| 7 | |||
| 8 | def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]): | ||
| 9 | return tokenizer.pad( | ||
| 10 | {"input_ids": input_ids}, | ||
| 11 | padding=True, | ||
| 12 | pad_to_multiple_of=tokenizer.model_max_length, | ||
| 13 | return_tensors="pt" | ||
| 14 | ) | ||
| 15 | |||
| 16 | |||
| 17 | def get_extended_embeddings( | ||
| 18 | text_encoder: CLIPTextModel, | ||
| 19 | input_ids: torch.LongTensor, | ||
| 20 | position_ids: Optional[torch.LongTensor] = None, | ||
| 21 | attention_mask=None | ||
| 22 | ): | ||
| 23 | model_max_length = text_encoder.config.max_position_embeddings | ||
| 24 | prompts = input_ids.shape[0] | ||
| 25 | |||
| 26 | input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) | ||
| 27 | if position_ids is not None: | ||
| 28 | position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) | ||
| 29 | if attention_mask is not None: | ||
| 30 | attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) | ||
| 31 | |||
| 32 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | ||
| 33 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | ||
| 34 | return text_embeddings | ||
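Note: a hedged usage sketch of the two free functions that replace PromptProcessor. The checkpoint name is an assumption, and the shapes in the comments follow from the padding/reshaping logic above rather than being asserted elsewhere.

```python
# Sketch: pad prompts to multiples of the 77-token context window, encode each
# 77-token chunk separately, then stitch the chunks back together per prompt.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from models.clip.util import get_extended_embeddings, unify_input_ids

name = "openai/clip-vit-large-patch14"  # assumed checkpoint
tokenizer = CLIPTokenizer.from_pretrained(name)
text_encoder = CLIPTextModel.from_pretrained(name)

prompts = [
    "a photo of a cat",
    "a richly detailed prompt " + "with many extra words " * 20,  # > 77 tokens
]
input_ids = tokenizer(prompts, padding="do_not_pad").input_ids

inputs = unify_input_ids(tokenizer, input_ids)  # padded to shape (2, k * 77)
with torch.no_grad():
    embeddings = get_extended_embeddings(text_encoder, inputs.input_ids)
print(embeddings.shape)  # (2, k * 77, hidden_size)
```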
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 6bc40e9..a5cfc60 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
| @@ -22,7 +22,7 @@ from diffusers import ( | |||
| 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
| 23 | from diffusers.utils import logging, randn_tensor | 23 | from diffusers.utils import logging, randn_tensor |
| 24 | from transformers import CLIPTextModel, CLIPTokenizer | 24 | from transformers import CLIPTextModel, CLIPTokenizer |
| 25 | from models.clip.prompt import PromptProcessor | 25 | from models.clip.util import unify_input_ids, get_extended_embeddings |
| 26 | 26 | ||
| 27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
| 28 | 28 | ||
| @@ -70,8 +70,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 70 | new_config["steps_offset"] = 1 | 70 | new_config["steps_offset"] = 1 |
| 71 | scheduler._internal_dict = FrozenDict(new_config) | 71 | scheduler._internal_dict = FrozenDict(new_config) |
| 72 | 72 | ||
| 73 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
| 74 | |||
| 75 | self.register_modules( | 73 | self.register_modules( |
| 76 | vae=vae, | 74 | vae=vae, |
| 77 | text_encoder=text_encoder, | 75 | text_encoder=text_encoder, |
| @@ -213,16 +211,22 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 213 | do_classifier_free_guidance: bool, | 211 | do_classifier_free_guidance: bool, |
| 214 | device | 212 | device |
| 215 | ): | 213 | ): |
| 216 | text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt | 214 | if isinstance(prompt[0], str): |
| 215 | text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids | ||
| 216 | else: | ||
| 217 | text_input_ids = prompt | ||
| 218 | |||
| 217 | text_input_ids *= num_images_per_prompt | 219 | text_input_ids *= num_images_per_prompt |
| 218 | 220 | ||
| 219 | if do_classifier_free_guidance: | 221 | if do_classifier_free_guidance: |
| 220 | unconditional_input_ids = self.prompt_processor.get_input_ids( | 222 | if isinstance(prompt[0], str): |
| 221 | negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt | 223 | unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids |
| 224 | else: | ||
| 225 | unconditional_input_ids = negative_prompt | ||
| 222 | unconditional_input_ids *= num_images_per_prompt | 226 | unconditional_input_ids *= num_images_per_prompt |
| 223 | text_input_ids = unconditional_input_ids + text_input_ids | 227 | text_input_ids = unconditional_input_ids + text_input_ids |
| 224 | 228 | ||
| 225 | text_inputs = self.prompt_processor.unify_input_ids(text_input_ids) | 229 | text_inputs = unify_input_ids(self.tokenizer, text_input_ids) |
| 226 | text_input_ids = text_inputs.input_ids | 230 | text_input_ids = text_inputs.input_ids |
| 227 | 231 | ||
| 228 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | 232 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: |
| @@ -230,7 +234,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 230 | else: | 234 | else: |
| 231 | attention_mask = None | 235 | attention_mask = None |
| 232 | 236 | ||
| 233 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask) | 237 | text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) |
| 234 | 238 | ||
| 235 | return text_embeddings | 239 | return text_embeddings |
| 236 | 240 | ||
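Note: encode_prompt keeps accepting either raw strings or pre-tokenized id lists; a tokenizer-only sketch of that branch (checkpoint name assumed, mirroring the pipeline code above).

```python
# Sketch of the prompt-or-ids branch now inlined in encode_prompt.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed

num_images_per_prompt = 2
prompt = ["a painting of a fox"]  # could equally be a list of token id lists

if isinstance(prompt[0], str):
    text_input_ids = tokenizer(prompt, padding="do_not_pad").input_ids
else:
    text_input_ids = prompt

text_input_ids *= num_images_per_prompt  # one copy per requested image
```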
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0fe590f..fbbe6c2 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
| @@ -27,7 +27,6 @@ from training.common import loss_step, generate_class_images, get_scheduler | |||
| 27 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
| 28 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args |
| 29 | from models.clip.embeddings import patch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
| 30 | from models.clip.prompt import PromptProcessor | ||
| 31 | from models.clip.tokenizer import MultiCLIPTokenizer | 30 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 32 | 31 | ||
| 33 | logger = get_logger(__name__) | 32 | logger = get_logger(__name__) |
| @@ -690,8 +689,6 @@ def main(): | |||
| 690 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | 689 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
| 691 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | 690 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) |
| 692 | 691 | ||
| 693 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
| 694 | |||
| 695 | if args.scale_lr: | 692 | if args.scale_lr: |
| 696 | args.learning_rate = ( | 693 | args.learning_rate = ( |
| 697 | args.learning_rate * args.gradient_accumulation_steps * | 694 | args.learning_rate * args.gradient_accumulation_steps * |
| @@ -751,7 +748,7 @@ def main(): | |||
| 751 | datamodule = VlpnDataModule( | 748 | datamodule = VlpnDataModule( |
| 752 | data_file=args.train_data_file, | 749 | data_file=args.train_data_file, |
| 753 | batch_size=args.train_batch_size, | 750 | batch_size=args.train_batch_size, |
| 754 | prompt_processor=prompt_processor, | 751 | tokenizer=tokenizer, |
| 755 | class_subdir=args.class_image_dir, | 752 | class_subdir=args.class_image_dir, |
| 756 | num_class_images=args.num_class_images, | 753 | num_class_images=args.num_class_images, |
| 757 | size=args.resolution, | 754 | size=args.resolution, |
| @@ -876,7 +873,7 @@ def main(): | |||
| 876 | vae, | 873 | vae, |
| 877 | noise_scheduler, | 874 | noise_scheduler, |
| 878 | unet, | 875 | unet, |
| 879 | prompt_processor, | 876 | text_encoder, |
| 880 | args.num_class_images, | 877 | args.num_class_images, |
| 881 | args.prior_loss_weight, | 878 | args.prior_loss_weight, |
| 882 | args.seed, | 879 | args.seed, |
diff --git a/train_ti.py b/train_ti.py
index e18ee38..8c86586 100644
--- a/train_ti.py
+++ b/train_ti.py
| @@ -21,11 +21,10 @@ from slugify import slugify | |||
| 21 | from util import load_config, load_embeddings_from_dir | 21 | from util import load_config, load_embeddings_from_dir |
| 22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 23 | from data.csv import VlpnDataModule, VlpnDataItem | 23 | from data.csv import VlpnDataModule, VlpnDataItem |
| 24 | from training.common import loss_step, generate_class_images, get_scheduler | 24 | from training.common import loss_step, train_loop, generate_class_images, get_scheduler |
| 25 | from training.lr import LRFinder | 25 | from training.lr import LRFinder |
| 26 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args | 26 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args |
| 27 | from models.clip.embeddings import patch_managed_embeddings | 27 | from models.clip.embeddings import patch_managed_embeddings |
| 28 | from models.clip.prompt import PromptProcessor | ||
| 29 | from models.clip.tokenizer import MultiCLIPTokenizer | 28 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 30 | 29 | ||
| 31 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
| @@ -198,12 +197,6 @@ def parse_args(): | |||
| 198 | default=100 | 197 | default=100 |
| 199 | ) | 198 | ) |
| 200 | parser.add_argument( | 199 | parser.add_argument( |
| 201 | "--max_train_steps", | ||
| 202 | type=int, | ||
| 203 | default=None, | ||
| 204 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | ||
| 205 | ) | ||
| 206 | parser.add_argument( | ||
| 207 | "--gradient_accumulation_steps", | 200 | "--gradient_accumulation_steps", |
| 208 | type=int, | 201 | type=int, |
| 209 | default=1, | 202 | default=1, |
| @@ -409,7 +402,7 @@ def parse_args(): | |||
| 409 | ) | 402 | ) |
| 410 | parser.add_argument( | 403 | parser.add_argument( |
| 411 | "--decay_target", | 404 | "--decay_target", |
| 412 | default=0.4, | 405 | default=None, |
| 413 | type=float, | 406 | type=float, |
| 414 | help="Embedding decay target." | 407 | help="Embedding decay target." |
| 415 | ) | 408 | ) |
| @@ -668,8 +661,6 @@ def main(): | |||
| 668 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | 661 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
| 669 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | 662 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) |
| 670 | 663 | ||
| 671 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
| 672 | |||
| 673 | if args.scale_lr: | 664 | if args.scale_lr: |
| 674 | args.learning_rate = ( | 665 | args.learning_rate = ( |
| 675 | args.learning_rate * args.gradient_accumulation_steps * | 666 | args.learning_rate * args.gradient_accumulation_steps * |
| @@ -722,7 +713,7 @@ def main(): | |||
| 722 | datamodule = VlpnDataModule( | 713 | datamodule = VlpnDataModule( |
| 723 | data_file=args.train_data_file, | 714 | data_file=args.train_data_file, |
| 724 | batch_size=args.train_batch_size, | 715 | batch_size=args.train_batch_size, |
| 725 | prompt_processor=prompt_processor, | 716 | tokenizer=tokenizer, |
| 726 | class_subdir=args.class_image_dir, | 717 | class_subdir=args.class_image_dir, |
| 727 | num_class_images=args.num_class_images, | 718 | num_class_images=args.num_class_images, |
| 728 | size=args.resolution, | 719 | size=args.resolution, |
| @@ -759,13 +750,7 @@ def main(): | |||
| 759 | args.sample_steps | 750 | args.sample_steps |
| 760 | ) | 751 | ) |
| 761 | 752 | ||
| 762 | # Scheduler and math around the number of training steps. | ||
| 763 | overrode_max_train_steps = False | ||
| 764 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 753 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
| 765 | if args.max_train_steps is None: | ||
| 766 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
| 767 | overrode_max_train_steps = True | ||
| 768 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
| 769 | 754 | ||
| 770 | if args.find_lr: | 755 | if args.find_lr: |
| 771 | lr_scheduler = None | 756 | lr_scheduler = None |
| @@ -781,7 +766,7 @@ def main(): | |||
| 781 | annealing_exp=args.lr_annealing_exp, | 766 | annealing_exp=args.lr_annealing_exp, |
| 782 | cycles=args.lr_cycles, | 767 | cycles=args.lr_cycles, |
| 783 | warmup_epochs=args.lr_warmup_epochs, | 768 | warmup_epochs=args.lr_warmup_epochs, |
| 784 | max_train_steps=args.max_train_steps, | 769 | num_train_epochs=args.num_train_epochs, |
| 785 | num_update_steps_per_epoch=num_update_steps_per_epoch, | 770 | num_update_steps_per_epoch=num_update_steps_per_epoch, |
| 786 | gradient_accumulation_steps=args.gradient_accumulation_steps | 771 | gradient_accumulation_steps=args.gradient_accumulation_steps |
| 787 | ) | 772 | ) |
| @@ -805,15 +790,6 @@ def main(): | |||
| 805 | else: | 790 | else: |
| 806 | unet.eval() | 791 | unet.eval() |
| 807 | 792 | ||
| 808 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | ||
| 809 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | ||
| 810 | if overrode_max_train_steps: | ||
| 811 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
| 812 | |||
| 813 | num_val_steps_per_epoch = len(val_dataloader) | ||
| 814 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
| 815 | val_steps = num_val_steps_per_epoch * num_epochs | ||
| 816 | |||
| 817 | @contextmanager | 793 | @contextmanager |
| 818 | def on_train(): | 794 | def on_train(): |
| 819 | try: | 795 | try: |
| @@ -842,19 +818,44 @@ def main(): | |||
| 842 | min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) | 818 | min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) |
| 843 | ) | 819 | ) |
| 844 | 820 | ||
| 821 | if args.use_ema: | ||
| 822 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
| 823 | |||
| 824 | def on_log(): | ||
| 825 | if args.use_ema: | ||
| 826 | return {"ema_decay": ema_embeddings.decay} | ||
| 827 | return {} | ||
| 828 | |||
| 845 | loop = partial( | 829 | loop = partial( |
| 846 | loss_step, | 830 | loss_step, |
| 847 | vae, | 831 | vae, |
| 848 | noise_scheduler, | 832 | noise_scheduler, |
| 849 | unet, | 833 | unet, |
| 850 | prompt_processor, | 834 | text_encoder, |
| 851 | args.num_class_images, | 835 | args.num_class_images, |
| 852 | args.prior_loss_weight, | 836 | args.prior_loss_weight, |
| 853 | args.seed, | 837 | args.seed, |
| 854 | ) | 838 | ) |
| 855 | 839 | ||
| 856 | # We need to initialize the trackers we use, and also store our configuration. | 840 | checkpointer = Checkpointer( |
| 857 | # The trackers initializes automatically on the main process. | 841 | weight_dtype=weight_dtype, |
| 842 | datamodule=datamodule, | ||
| 843 | accelerator=accelerator, | ||
| 844 | vae=vae, | ||
| 845 | unet=unet, | ||
| 846 | tokenizer=tokenizer, | ||
| 847 | text_encoder=text_encoder, | ||
| 848 | ema_embeddings=ema_embeddings, | ||
| 849 | scheduler=checkpoint_scheduler, | ||
| 850 | placeholder_token=args.placeholder_token, | ||
| 851 | new_ids=new_ids, | ||
| 852 | output_dir=basepath, | ||
| 853 | sample_image_size=args.sample_image_size, | ||
| 854 | sample_batch_size=args.sample_batch_size, | ||
| 855 | sample_batches=args.sample_batches, | ||
| 856 | seed=args.seed | ||
| 857 | ) | ||
| 858 | |||
| 858 | if accelerator.is_main_process: | 859 | if accelerator.is_main_process: |
| 859 | config = vars(args).copy() | 860 | config = vars(args).copy() |
| 860 | config["initializer_token"] = " ".join(config["initializer_token"]) | 861 | config["initializer_token"] = " ".join(config["initializer_token"]) |
| @@ -882,190 +883,27 @@ def main(): | |||
| 882 | 883 | ||
| 883 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 884 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
| 884 | plt.close() | 885 | plt.close() |
| 885 | 886 | else: | |
| 886 | quit() | 887 | train_loop( |
| 887 | 888 | accelerator=accelerator, | |
| 888 | # Train! | 889 | optimizer=optimizer, |
| 889 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | 890 | lr_scheduler=lr_scheduler, |
| 890 | 891 | model=text_encoder, | |
| 891 | logger.info("***** Running training *****") | 892 | checkpointer=checkpointer, |
| 892 | logger.info(f" Num Epochs = {num_epochs}") | 893 | train_dataloader=train_dataloader, |
| 893 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") | 894 | val_dataloader=val_dataloader, |
| 894 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") | 895 | loss_step=loop, |
| 895 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | 896 | sample_frequency=args.sample_frequency, |
| 896 | logger.info(f" Total optimization steps = {args.max_train_steps}") | 897 | sample_steps=args.sample_steps, |
| 897 | # Only show the progress bar once on each machine. | 898 | checkpoint_frequency=args.checkpoint_frequency, |
| 898 | 899 | global_step_offset=global_step_offset, | |
| 899 | global_step = 0 | 900 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 900 | 901 | num_epochs=args.num_train_epochs, | |
| 901 | avg_loss = AverageMeter() | 902 | on_log=on_log, |
| 902 | avg_acc = AverageMeter() | 903 | on_train=on_train, |
| 903 | 904 | on_after_optimize=on_after_optimize, | |
| 904 | avg_loss_val = AverageMeter() | 905 | on_eval=on_eval |
| 905 | avg_acc_val = AverageMeter() | 906 | ) |
| 906 | |||
| 907 | max_acc_val = 0.0 | ||
| 908 | |||
| 909 | checkpointer = Checkpointer( | ||
| 910 | weight_dtype=weight_dtype, | ||
| 911 | datamodule=datamodule, | ||
| 912 | accelerator=accelerator, | ||
| 913 | vae=vae, | ||
| 914 | unet=unet, | ||
| 915 | tokenizer=tokenizer, | ||
| 916 | text_encoder=text_encoder, | ||
| 917 | ema_embeddings=ema_embeddings, | ||
| 918 | scheduler=checkpoint_scheduler, | ||
| 919 | placeholder_token=args.placeholder_token, | ||
| 920 | new_ids=new_ids, | ||
| 921 | output_dir=basepath, | ||
| 922 | sample_image_size=args.sample_image_size, | ||
| 923 | sample_batch_size=args.sample_batch_size, | ||
| 924 | sample_batches=args.sample_batches, | ||
| 925 | seed=args.seed | ||
| 926 | ) | ||
| 927 | |||
| 928 | local_progress_bar = tqdm( | ||
| 929 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | ||
| 930 | disable=not accelerator.is_local_main_process, | ||
| 931 | dynamic_ncols=True | ||
| 932 | ) | ||
| 933 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") | ||
| 934 | |||
| 935 | global_progress_bar = tqdm( | ||
| 936 | range(args.max_train_steps + val_steps), | ||
| 937 | disable=not accelerator.is_local_main_process, | ||
| 938 | dynamic_ncols=True | ||
| 939 | ) | ||
| 940 | global_progress_bar.set_description("Total progress") | ||
| 941 | |||
| 942 | try: | ||
| 943 | for epoch in range(num_epochs): | ||
| 944 | if accelerator.is_main_process: | ||
| 945 | if epoch % args.sample_frequency == 0: | ||
| 946 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) | ||
| 947 | |||
| 948 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | ||
| 949 | local_progress_bar.reset() | ||
| 950 | |||
| 951 | text_encoder.train() | ||
| 952 | |||
| 953 | with on_train(): | ||
| 954 | for step, batch in enumerate(train_dataloader): | ||
| 955 | with accelerator.accumulate(text_encoder): | ||
| 956 | loss, acc, bsz = loop(step, batch) | ||
| 957 | |||
| 958 | accelerator.backward(loss) | ||
| 959 | |||
| 960 | optimizer.step() | ||
| 961 | lr_scheduler.step() | ||
| 962 | optimizer.zero_grad(set_to_none=True) | ||
| 963 | |||
| 964 | avg_loss.update(loss.detach_(), bsz) | ||
| 965 | avg_acc.update(acc.detach_(), bsz) | ||
| 966 | |||
| 967 | # Checks if the accelerator has performed an optimization step behind the scenes | ||
| 968 | if accelerator.sync_gradients: | ||
| 969 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | ||
| 970 | |||
| 971 | if args.use_ema: | ||
| 972 | ema_embeddings.step( | ||
| 973 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
| 974 | |||
| 975 | local_progress_bar.update(1) | ||
| 976 | global_progress_bar.update(1) | ||
| 977 | |||
| 978 | global_step += 1 | ||
| 979 | |||
| 980 | logs = { | ||
| 981 | "train/loss": avg_loss.avg.item(), | ||
| 982 | "train/acc": avg_acc.avg.item(), | ||
| 983 | "train/cur_loss": loss.item(), | ||
| 984 | "train/cur_acc": acc.item(), | ||
| 985 | "lr": lr_scheduler.get_last_lr()[0], | ||
| 986 | } | ||
| 987 | if args.use_ema: | ||
| 988 | logs["ema_decay"] = ema_embeddings.decay | ||
| 989 | |||
| 990 | accelerator.log(logs, step=global_step) | ||
| 991 | |||
| 992 | local_progress_bar.set_postfix(**logs) | ||
| 993 | |||
| 994 | if global_step >= args.max_train_steps: | ||
| 995 | break | ||
| 996 | |||
| 997 | accelerator.wait_for_everyone() | ||
| 998 | |||
| 999 | text_encoder.eval() | ||
| 1000 | |||
| 1001 | cur_loss_val = AverageMeter() | ||
| 1002 | cur_acc_val = AverageMeter() | ||
| 1003 | |||
| 1004 | with torch.inference_mode(): | ||
| 1005 | with on_eval(): | ||
| 1006 | for step, batch in enumerate(val_dataloader): | ||
| 1007 | loss, acc, bsz = loop(step, batch, True) | ||
| 1008 | |||
| 1009 | loss = loss.detach_() | ||
| 1010 | acc = acc.detach_() | ||
| 1011 | |||
| 1012 | cur_loss_val.update(loss, bsz) | ||
| 1013 | cur_acc_val.update(acc, bsz) | ||
| 1014 | |||
| 1015 | avg_loss_val.update(loss, bsz) | ||
| 1016 | avg_acc_val.update(acc, bsz) | ||
| 1017 | |||
| 1018 | local_progress_bar.update(1) | ||
| 1019 | global_progress_bar.update(1) | ||
| 1020 | |||
| 1021 | logs = { | ||
| 1022 | "val/loss": avg_loss_val.avg.item(), | ||
| 1023 | "val/acc": avg_acc_val.avg.item(), | ||
| 1024 | "val/cur_loss": loss.item(), | ||
| 1025 | "val/cur_acc": acc.item(), | ||
| 1026 | } | ||
| 1027 | local_progress_bar.set_postfix(**logs) | ||
| 1028 | |||
| 1029 | logs["val/cur_loss"] = cur_loss_val.avg.item() | ||
| 1030 | logs["val/cur_acc"] = cur_acc_val.avg.item() | ||
| 1031 | |||
| 1032 | accelerator.log(logs, step=global_step) | ||
| 1033 | |||
| 1034 | local_progress_bar.clear() | ||
| 1035 | global_progress_bar.clear() | ||
| 1036 | |||
| 1037 | if accelerator.is_main_process: | ||
| 1038 | if avg_acc_val.avg.item() > max_acc_val: | ||
| 1039 | accelerator.print( | ||
| 1040 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | ||
| 1041 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | ||
| 1042 | max_acc_val = avg_acc_val.avg.item() | ||
| 1043 | |||
| 1044 | if (epoch + 1) % args.checkpoint_frequency == 0: | ||
| 1045 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
| 1046 | save_args(basepath, args, { | ||
| 1047 | "global_step": global_step + global_step_offset | ||
| 1048 | }) | ||
| 1049 | |||
| 1050 | # Create the pipeline using using the trained modules and save it. | ||
| 1051 | if accelerator.is_main_process: | ||
| 1052 | print("Finished! Saving final checkpoint and resume state.") | ||
| 1053 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
| 1054 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) | ||
| 1055 | save_args(basepath, args, { | ||
| 1056 | "global_step": global_step + global_step_offset | ||
| 1057 | }) | ||
| 1058 | accelerator.end_training() | ||
| 1059 | |||
| 1060 | except KeyboardInterrupt: | ||
| 1061 | if accelerator.is_main_process: | ||
| 1062 | print("Interrupted, saving checkpoint and resume state...") | ||
| 1063 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
| 1064 | save_args(basepath, args, { | ||
| 1065 | "global_step": global_step + global_step_offset | ||
| 1066 | }) | ||
| 1067 | accelerator.end_training() | ||
| 1068 | quit() | ||
| 1069 | 907 | ||
| 1070 | 908 | ||
| 1071 | if __name__ == "__main__": | 909 | if __name__ == "__main__": |
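Note: the hooks handed to train_loop follow a small contract: on_train/on_eval are context managers, on_log returns extra key/value pairs merged into the step logs, and on_after_optimize receives the current learning rate. Below is a minimal placeholder sketch of that contract; the bodies are illustrative, not the ones used in train_ti.py.

```python
# Placeholder hook implementations illustrating the callback contract.
from contextlib import contextmanager

@contextmanager
def on_train():
    try:
        # e.g. enable tokenizer dropout / put embeddings into training mode
        yield
    finally:
        pass

@contextmanager
def on_eval():
    try:
        # e.g. swap in EMA weights for validation
        yield
    finally:
        pass

def on_log() -> dict:
    return {"ema_decay": 0.9999}  # merged into the logged metrics

def on_after_optimize(lr: float):
    print(f"optimizer stepped at lr={lr:.2e}")
```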
diff --git a/training/common.py b/training/common.py
index 90cf910..842ac07 100644
--- a/training/common.py
+++ b/training/common.py
| @@ -1,14 +1,30 @@ | |||
| 1 | import math | 1 | import math |
| 2 | from contextlib import _GeneratorContextManager, nullcontext | ||
| 3 | from typing import Callable, Any, Tuple, Union | ||
| 2 | 4 | ||
| 3 | import torch | 5 | import torch |
| 4 | import torch.nn.functional as F | 6 | import torch.nn.functional as F |
| 7 | from torch.utils.data import DataLoader | ||
| 5 | 8 | ||
| 9 | from accelerate import Accelerator | ||
| 10 | from transformers import CLIPTokenizer, CLIPTextModel | ||
| 6 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 11 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
| 7 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 12 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup |
| 8 | 13 | ||
| 9 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 14 | from tqdm.auto import tqdm |
| 10 | 15 | ||
| 16 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 17 | from models.clip.util import get_extended_embeddings | ||
| 11 | from training.optimization import get_one_cycle_schedule | 18 | from training.optimization import get_one_cycle_schedule |
| 19 | from training.util import AverageMeter, CheckpointerBase | ||
| 20 | |||
| 21 | |||
| 22 | def noop(*args, **kwards): | ||
| 23 | pass | ||
| 24 | |||
| 25 | |||
| 26 | def noop_on_log(): | ||
| 27 | return {} | ||
| 12 | 28 | ||
| 13 | 29 | ||
| 14 | def get_scheduler( | 30 | def get_scheduler( |
| @@ -22,10 +38,11 @@ def get_scheduler( | |||
| 22 | cycles: int, | 38 | cycles: int, |
| 23 | warmup_epochs: int, | 39 | warmup_epochs: int, |
| 24 | optimizer: torch.optim.Optimizer, | 40 | optimizer: torch.optim.Optimizer, |
| 25 | max_train_steps: int, | 41 | num_train_epochs: int, |
| 26 | num_update_steps_per_epoch: int, | 42 | num_update_steps_per_epoch: int, |
| 27 | gradient_accumulation_steps: int, | 43 | gradient_accumulation_steps: int, |
| 28 | ): | 44 | ): |
| 45 | num_train_steps = num_train_epochs * num_update_steps_per_epoch | ||
| 29 | warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps | 46 | warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps |
| 30 | 47 | ||
| 31 | if id == "one_cycle": | 48 | if id == "one_cycle": |
| @@ -33,7 +50,7 @@ def get_scheduler( | |||
| 33 | 50 | ||
| 34 | lr_scheduler = get_one_cycle_schedule( | 51 | lr_scheduler = get_one_cycle_schedule( |
| 35 | optimizer=optimizer, | 52 | optimizer=optimizer, |
| 36 | num_training_steps=max_train_steps * gradient_accumulation_steps, | 53 | num_training_steps=num_train_steps * gradient_accumulation_steps, |
| 37 | warmup=warmup_func, | 54 | warmup=warmup_func, |
| 38 | annealing=annealing_func, | 55 | annealing=annealing_func, |
| 39 | warmup_exp=warmup_exp, | 56 | warmup_exp=warmup_exp, |
| @@ -42,12 +59,12 @@ def get_scheduler( | |||
| 42 | ) | 59 | ) |
| 43 | elif id == "cosine_with_restarts": | 60 | elif id == "cosine_with_restarts": |
| 44 | cycles = cycles if cycles is not None else math.ceil( | 61 | cycles = cycles if cycles is not None else math.ceil( |
| 45 | math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch))) | 62 | math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch))) |
| 46 | 63 | ||
| 47 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 64 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
| 48 | optimizer=optimizer, | 65 | optimizer=optimizer, |
| 49 | num_warmup_steps=warmup_steps, | 66 | num_warmup_steps=warmup_steps, |
| 50 | num_training_steps=max_train_steps * gradient_accumulation_steps, | 67 | num_training_steps=num_train_steps * gradient_accumulation_steps, |
| 51 | num_cycles=cycles, | 68 | num_cycles=cycles, |
| 52 | ) | 69 | ) |
| 53 | else: | 70 | else: |
| @@ -55,7 +72,7 @@ def get_scheduler( | |||
| 55 | id, | 72 | id, |
| 56 | optimizer=optimizer, | 73 | optimizer=optimizer, |
| 57 | num_warmup_steps=warmup_steps, | 74 | num_warmup_steps=warmup_steps, |
| 58 | num_training_steps=max_train_steps * gradient_accumulation_steps, | 75 | num_training_steps=num_train_steps * gradient_accumulation_steps, |
| 59 | ) | 76 | ) |
| 60 | 77 | ||
| 61 | return lr_scheduler | 78 | return lr_scheduler |
| @@ -117,12 +134,12 @@ def loss_step( | |||
| 117 | vae: AutoencoderKL, | 134 | vae: AutoencoderKL, |
| 118 | noise_scheduler: DDPMScheduler, | 135 | noise_scheduler: DDPMScheduler, |
| 119 | unet: UNet2DConditionModel, | 136 | unet: UNet2DConditionModel, |
| 120 | prompt_processor, | 137 | text_encoder: CLIPTextModel, |
| 121 | num_class_images: int, | 138 | num_class_images: int, |
| 122 | prior_loss_weight: float, | 139 | prior_loss_weight: float, |
| 123 | seed: int, | 140 | seed: int, |
| 124 | step: int, | 141 | step: int, |
| 125 | batch, | 142 | batch: dict[str, Any], |
| 126 | eval: bool = False | 143 | eval: bool = False |
| 127 | ): | 144 | ): |
| 128 | # Convert images to latent space | 145 | # Convert images to latent space |
| @@ -149,7 +166,8 @@ def loss_step( | |||
| 149 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 166 | noisy_latents = noisy_latents.to(dtype=unet.dtype) |
| 150 | 167 | ||
| 151 | # Get the text embedding for conditioning | 168 | # Get the text embedding for conditioning |
| 152 | encoder_hidden_states = prompt_processor.get_embeddings( | 169 | encoder_hidden_states = get_extended_embeddings( |
| 170 | text_encoder, | ||
| 153 | batch["input_ids"], | 171 | batch["input_ids"], |
| 154 | batch["attention_mask"] | 172 | batch["attention_mask"] |
| 155 | ) | 173 | ) |
| @@ -185,3 +203,172 @@ def loss_step( | |||
| 185 | acc = (model_pred == target).float().mean() | 203 | acc = (model_pred == target).float().mean() |
| 186 | 204 | ||
| 187 | return loss, acc, bsz | 205 | return loss, acc, bsz |
| 206 | |||
| 207 | |||
| 208 | def train_loop( | ||
| 209 | accelerator: Accelerator, | ||
| 210 | optimizer: torch.optim.Optimizer, | ||
| 211 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
| 212 | model: torch.nn.Module, | ||
| 213 | checkpointer: CheckpointerBase, | ||
| 214 | train_dataloader: DataLoader, | ||
| 215 | val_dataloader: DataLoader, | ||
| 216 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | ||
| 217 | sample_frequency: int = 10, | ||
| 218 | sample_steps: int = 20, | ||
| 219 | checkpoint_frequency: int = 50, | ||
| 220 | global_step_offset: int = 0, | ||
| 221 | gradient_accumulation_steps: int = 1, | ||
| 222 | num_epochs: int = 100, | ||
| 223 | on_log: Callable[[], dict[str, Any]] = noop_on_log, | ||
| 224 | on_train: Callable[[], _GeneratorContextManager] = nullcontext, | ||
| 225 | on_before_optimize: Callable[[], None] = noop, | ||
| 226 | on_after_optimize: Callable[[float], None] = noop, | ||
| 227 | on_eval: Callable[[], _GeneratorContextManager] = nullcontext | ||
| 228 | ): | ||
| 229 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | ||
| 230 | num_train_steps = num_epochs * num_update_steps_per_epoch | ||
| 231 | |||
| 232 | num_val_steps_per_epoch = len(val_dataloader) | ||
| 233 | num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch) | ||
| 234 | num_val_steps = num_val_steps_per_epoch * num_epochs | ||
| 235 | |||
| 236 | global_step = 0 | ||
| 237 | |||
| 238 | avg_loss = AverageMeter() | ||
| 239 | avg_acc = AverageMeter() | ||
| 240 | |||
| 241 | avg_loss_val = AverageMeter() | ||
| 242 | avg_acc_val = AverageMeter() | ||
| 243 | |||
| 244 | max_acc_val = 0.0 | ||
| 245 | |||
| 246 | local_progress_bar = tqdm( | ||
| 247 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | ||
| 248 | disable=not accelerator.is_local_main_process, | ||
| 249 | dynamic_ncols=True | ||
| 250 | ) | ||
| 251 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") | ||
| 252 | |||
| 253 | global_progress_bar = tqdm( | ||
| 254 | range(num_train_steps + num_val_steps), | ||
| 255 | disable=not accelerator.is_local_main_process, | ||
| 256 | dynamic_ncols=True | ||
| 257 | ) | ||
| 258 | global_progress_bar.set_description("Total progress") | ||
| 259 | |||
| 260 | try: | ||
| 261 | for epoch in range(num_epochs): | ||
| 262 | if accelerator.is_main_process: | ||
| 263 | if epoch % sample_frequency == 0: | ||
| 264 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | ||
| 265 | |||
| 266 | if epoch % checkpoint_frequency == 0 and epoch != 0: | ||
| 267 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
| 268 | |||
| 269 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | ||
| 270 | local_progress_bar.reset() | ||
| 271 | |||
| 272 | model.train() | ||
| 273 | |||
| 274 | with on_train(): | ||
| 275 | for step, batch in enumerate(train_dataloader): | ||
| 276 | with accelerator.accumulate(model): | ||
| 277 | loss, acc, bsz = loss_step(step, batch) | ||
| 278 | |||
| 279 | accelerator.backward(loss) | ||
| 280 | |||
| 281 | on_before_optimize() | ||
| 282 | |||
| 283 | optimizer.step() | ||
| 284 | lr_scheduler.step() | ||
| 285 | optimizer.zero_grad(set_to_none=True) | ||
| 286 | |||
| 287 | avg_loss.update(loss.detach_(), bsz) | ||
| 288 | avg_acc.update(acc.detach_(), bsz) | ||
| 289 | |||
| 290 | # Checks if the accelerator has performed an optimization step behind the scenes | ||
| 291 | if accelerator.sync_gradients: | ||
| 292 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | ||
| 293 | |||
| 294 | local_progress_bar.update(1) | ||
| 295 | global_progress_bar.update(1) | ||
| 296 | |||
| 297 | global_step += 1 | ||
| 298 | |||
| 299 | logs = { | ||
| 300 | "train/loss": avg_loss.avg.item(), | ||
| 301 | "train/acc": avg_acc.avg.item(), | ||
| 302 | "train/cur_loss": loss.item(), | ||
| 303 | "train/cur_acc": acc.item(), | ||
| 304 | "lr": lr_scheduler.get_last_lr()[0], | ||
| 305 | } | ||
| 306 | logs.update(on_log()) | ||
| 307 | |||
| 308 | accelerator.log(logs, step=global_step) | ||
| 309 | |||
| 310 | local_progress_bar.set_postfix(**logs) | ||
| 311 | |||
| 312 | if global_step >= num_train_steps: | ||
| 313 | break | ||
| 314 | |||
| 315 | accelerator.wait_for_everyone() | ||
| 316 | |||
| 317 | model.eval() | ||
| 318 | |||
| 319 | cur_loss_val = AverageMeter() | ||
| 320 | cur_acc_val = AverageMeter() | ||
| 321 | |||
| 322 | with torch.inference_mode(): | ||
| 323 | with on_eval(): | ||
| 324 | for step, batch in enumerate(val_dataloader): | ||
| 325 | loss, acc, bsz = loss_step(step, batch, True) | ||
| 326 | |||
| 327 | loss = loss.detach_() | ||
| 328 | acc = acc.detach_() | ||
| 329 | |||
| 330 | cur_loss_val.update(loss, bsz) | ||
| 331 | cur_acc_val.update(acc, bsz) | ||
| 332 | |||
| 333 | avg_loss_val.update(loss, bsz) | ||
| 334 | avg_acc_val.update(acc, bsz) | ||
| 335 | |||
| 336 | local_progress_bar.update(1) | ||
| 337 | global_progress_bar.update(1) | ||
| 338 | |||
| 339 | logs = { | ||
| 340 | "val/loss": avg_loss_val.avg.item(), | ||
| 341 | "val/acc": avg_acc_val.avg.item(), | ||
| 342 | "val/cur_loss": loss.item(), | ||
| 343 | "val/cur_acc": acc.item(), | ||
| 344 | } | ||
| 345 | local_progress_bar.set_postfix(**logs) | ||
| 346 | |||
| 347 | logs["val/cur_loss"] = cur_loss_val.avg.item() | ||
| 348 | logs["val/cur_acc"] = cur_acc_val.avg.item() | ||
| 349 | |||
| 350 | accelerator.log(logs, step=global_step) | ||
| 351 | |||
| 352 | local_progress_bar.clear() | ||
| 353 | global_progress_bar.clear() | ||
| 354 | |||
| 355 | if accelerator.is_main_process: | ||
| 356 | if avg_acc_val.avg.item() > max_acc_val: | ||
| 357 | accelerator.print( | ||
| 358 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | ||
| 359 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | ||
| 360 | max_acc_val = avg_acc_val.avg.item() | ||
| 361 | |||
| 362 | # Create the pipeline using using the trained modules and save it. | ||
| 363 | if accelerator.is_main_process: | ||
| 364 | print("Finished!") | ||
| 365 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
| 366 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | ||
| 367 | accelerator.end_training() | ||
| 368 | |||
| 369 | except KeyboardInterrupt: | ||
| 370 | if accelerator.is_main_process: | ||
| 371 | print("Interrupted") | ||
| 372 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
| 373 | accelerator.end_training() | ||
| 374 | quit() | ||
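Note: with max_train_steps removed, the total step count is now derived from epochs. A quick arithmetic sketch of the formulas used in get_scheduler and train_loop; all numbers are made up for illustration.

```python
# Made-up numbers, just to show how the epoch-based sizing works out.
import math

len_train_dataloader = 250
gradient_accumulation_steps = 4
num_train_epochs = 100
warmup_epochs = 10

num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 63
num_train_steps = num_train_epochs * num_update_steps_per_epoch                             # 6300
warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps     # 2520

# The schedulers receive num_train_steps * gradient_accumulation_steps total
# steps (25200 here) and warmup_steps warmup steps.
print(num_update_steps_per_epoch, num_train_steps, warmup_steps)
```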
diff --git a/training/util.py b/training/util.py
index 60d64f0..0ec2032 100644
--- a/training/util.py
+++ b/training/util.py
| @@ -55,8 +55,19 @@ class CheckpointerBase: | |||
| 55 | self.sample_batches = sample_batches | 55 | self.sample_batches = sample_batches |
| 56 | self.sample_batch_size = sample_batch_size | 56 | self.sample_batch_size = sample_batch_size |
| 57 | 57 | ||
| 58 | @torch.no_grad() | ||
| 59 | def checkpoint(self, step: int, postfix: str): | ||
| 60 | pass | ||
| 61 | |||
| 58 | @torch.inference_mode() | 62 | @torch.inference_mode() |
| 59 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 63 | def save_samples( |
| 64 | self, | ||
| 65 | pipeline, | ||
| 66 | step: int, | ||
| 67 | num_inference_steps: int, | ||
| 68 | guidance_scale: float = 7.5, | ||
| 69 | eta: float = 0.0 | ||
| 70 | ): | ||
| 60 | samples_path = Path(self.output_dir).joinpath("samples") | 71 | samples_path = Path(self.output_dir).joinpath("samples") |
| 61 | 72 | ||
| 62 | train_data = self.datamodule.train_dataloader | 73 | train_data = self.datamodule.train_dataloader |
