 data/csv.py                                          |  36
 models/clip/embeddings.py                            |   6
 models/clip/prompt.py                                |  38
 models/clip/util.py                                  |  34
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  20
 train_dreambooth.py                                  |   7
 train_ti.py                                          | 268
 training/common.py                                   | 205
 training/util.py                                     |  13
 9 files changed, 334 insertions(+), 293 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index f5fc8e6..a3fef30 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -9,9 +9,10 @@ from PIL import Image | |||
9 | 9 | ||
10 | from torch.utils.data import IterableDataset, DataLoader, random_split | 10 | from torch.utils.data import IterableDataset, DataLoader, random_split |
11 | from torchvision import transforms | 11 | from torchvision import transforms |
12 | from transformers import CLIPTokenizer | ||
12 | 13 | ||
13 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 14 | from data.keywords import prompt_to_keywords, keywords_to_prompt |
14 | from models.clip.prompt import PromptProcessor | 15 | from models.clip.util import unify_input_ids |
15 | 16 | ||
16 | 17 | ||
17 | image_cache: dict[str, Image.Image] = {} | 18 | image_cache: dict[str, Image.Image] = {} |
@@ -102,7 +103,7 @@ def generate_buckets( | |||
102 | def collate_fn( | 103 | def collate_fn( |
103 | num_class_images: int, | 104 | num_class_images: int, |
104 | weight_dtype: torch.dtype, | 105 | weight_dtype: torch.dtype, |
105 | prompt_processor: PromptProcessor, | 106 | tokenizer: CLIPTokenizer, |
106 | examples | 107 | examples |
107 | ): | 108 | ): |
108 | prompt_ids = [example["prompt_ids"] for example in examples] | 109 | prompt_ids = [example["prompt_ids"] for example in examples] |
@@ -119,9 +120,9 @@ def collate_fn( | |||
119 | pixel_values = torch.stack(pixel_values) | 120 | pixel_values = torch.stack(pixel_values) |
120 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 121 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
121 | 122 | ||
122 | prompts = prompt_processor.unify_input_ids(prompt_ids) | 123 | prompts = unify_input_ids(tokenizer, prompt_ids) |
123 | nprompts = prompt_processor.unify_input_ids(nprompt_ids) | 124 | nprompts = unify_input_ids(tokenizer, nprompt_ids) |
124 | inputs = prompt_processor.unify_input_ids(input_ids) | 125 | inputs = unify_input_ids(tokenizer, input_ids) |
125 | 126 | ||
126 | batch = { | 127 | batch = { |
127 | "prompt_ids": prompts.input_ids, | 128 | "prompt_ids": prompts.input_ids, |
@@ -148,7 +149,7 @@ class VlpnDataModule(): | |||
148 | self, | 149 | self, |
149 | batch_size: int, | 150 | batch_size: int, |
150 | data_file: str, | 151 | data_file: str, |
151 | prompt_processor: PromptProcessor, | 152 | tokenizer: CLIPTokenizer, |
152 | class_subdir: str = "cls", | 153 | class_subdir: str = "cls", |
153 | num_class_images: int = 1, | 154 | num_class_images: int = 1, |
154 | size: int = 768, | 155 | size: int = 768, |
@@ -179,7 +180,7 @@ class VlpnDataModule(): | |||
179 | self.class_root.mkdir(parents=True, exist_ok=True) | 180 | self.class_root.mkdir(parents=True, exist_ok=True) |
180 | self.num_class_images = num_class_images | 181 | self.num_class_images = num_class_images |
181 | 182 | ||
182 | self.prompt_processor = prompt_processor | 183 | self.tokenizer = tokenizer |
183 | self.size = size | 184 | self.size = size |
184 | self.num_buckets = num_buckets | 185 | self.num_buckets = num_buckets |
185 | self.bucket_step_size = bucket_step_size | 186 | self.bucket_step_size = bucket_step_size |
@@ -272,7 +273,7 @@ class VlpnDataModule(): | |||
272 | self.data_val = self.pad_items(data_val) | 273 | self.data_val = self.pad_items(data_val) |
273 | 274 | ||
274 | train_dataset = VlpnDataset( | 275 | train_dataset = VlpnDataset( |
275 | self.data_train, self.prompt_processor, | 276 | self.data_train, self.tokenizer, |
276 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 277 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
277 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 278 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
278 | batch_size=self.batch_size, generator=generator, | 279 | batch_size=self.batch_size, generator=generator, |
@@ -281,7 +282,7 @@ class VlpnDataModule(): | |||
281 | ) | 282 | ) |
282 | 283 | ||
283 | val_dataset = VlpnDataset( | 284 | val_dataset = VlpnDataset( |
284 | self.data_val, self.prompt_processor, | 285 | self.data_val, self.tokenizer, |
285 | num_buckets=self.num_buckets, progressive_buckets=True, | 286 | num_buckets=self.num_buckets, progressive_buckets=True, |
286 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 287 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
287 | repeat=self.valid_set_repeat, | 288 | repeat=self.valid_set_repeat, |
@@ -289,7 +290,7 @@ class VlpnDataModule(): | |||
289 | size=self.size, interpolation=self.interpolation, | 290 | size=self.size, interpolation=self.interpolation, |
290 | ) | 291 | ) |
291 | 292 | ||
292 | collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor) | 293 | collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer) |
293 | 294 | ||
294 | self.train_dataloader = DataLoader( | 295 | self.train_dataloader = DataLoader( |
295 | train_dataset, | 296 | train_dataset, |
@@ -306,7 +307,7 @@ class VlpnDataset(IterableDataset): | |||
306 | def __init__( | 307 | def __init__( |
307 | self, | 308 | self, |
308 | items: list[VlpnDataItem], | 309 | items: list[VlpnDataItem], |
309 | prompt_processor: PromptProcessor, | 310 | tokenizer: CLIPTokenizer, |
310 | num_buckets: int = 1, | 311 | num_buckets: int = 1, |
311 | bucket_step_size: int = 64, | 312 | bucket_step_size: int = 64, |
312 | bucket_max_pixels: Optional[int] = None, | 313 | bucket_max_pixels: Optional[int] = None, |
@@ -323,7 +324,7 @@ class VlpnDataset(IterableDataset): | |||
323 | self.items = items * repeat | 324 | self.items = items * repeat |
324 | self.batch_size = batch_size | 325 | self.batch_size = batch_size |
325 | 326 | ||
326 | self.prompt_processor = prompt_processor | 327 | self.tokenizer = tokenizer |
327 | self.num_class_images = num_class_images | 328 | self.num_class_images = num_class_images |
328 | self.size = size | 329 | self.size = size |
329 | self.dropout = dropout | 330 | self.dropout = dropout |
@@ -344,6 +345,9 @@ class VlpnDataset(IterableDataset): | |||
344 | 345 | ||
345 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 346 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() |
346 | 347 | ||
348 | def get_input_ids(self, text: str): | ||
349 | return self.tokenizer(text, padding="do_not_pad").input_ids | ||
350 | |||
347 | def __len__(self): | 351 | def __len__(self): |
348 | return self.length_ | 352 | return self.length_ |
349 | 353 | ||
@@ -404,16 +408,16 @@ class VlpnDataset(IterableDataset): | |||
404 | 408 | ||
405 | example = {} | 409 | example = {} |
406 | 410 | ||
407 | example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt)) | 411 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) |
408 | example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt) | 412 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
409 | 413 | ||
410 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 414 | example["instance_prompt_ids"] = self.get_input_ids( |
411 | keywords_to_prompt(item.prompt, self.dropout, True) | 415 | keywords_to_prompt(item.prompt, self.dropout, True) |
412 | ) | 416 | ) |
413 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 417 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
414 | 418 | ||
415 | if self.num_class_images != 0: | 419 | if self.num_class_images != 0: |
416 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt) | 420 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) |
417 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 421 | example["class_images"] = image_transforms(get_image(item.class_image_path)) |
418 | 422 | ||
419 | batch.append(example) | 423 | batch.append(example) |
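With PromptProcessor gone, VlpnDataset tokenizes each prompt itself (get_input_ids above) and collate_fn pads whole batches through the new module-level helper. A minimal sketch of that split, assuming the stock Hugging Face CLIP tokenizer; the prompts are made up:

```python
from transformers import CLIPTokenizer

from models.clip.util import unify_input_ids

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Per example (what VlpnDataset.get_input_ids now does): tokenize without padding.
ids_a = tokenizer("a photo of a cat", padding="do_not_pad").input_ids
ids_b = tokenizer("a watercolor painting of a dog on a beach", padding="do_not_pad").input_ids

# Per batch (what collate_fn now does): pad everything at once, to a multiple
# of tokenizer.model_max_length, and get tensors back.
prompts = unify_input_ids(tokenizer, [ids_a, ids_b])
print(prompts.input_ids.shape)  # (2, 77) for these short prompts
```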
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9a23a2a..761efbc 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -40,6 +40,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
40 | self.position_embedding = embeddings.position_embedding | 40 | self.position_embedding = embeddings.position_embedding |
41 | self.initializer_factor = config.initializer_factor | 41 | self.initializer_factor = config.initializer_factor |
42 | 42 | ||
43 | self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item() | ||
44 | |||
43 | self.temp_token_embedding = nn.Embedding( | 45 | self.temp_token_embedding = nn.Embedding( |
44 | self.token_embedding.num_embeddings, | 46 | self.token_embedding.num_embeddings, |
45 | self.token_embedding.embedding_dim, | 47 | self.token_embedding.embedding_dim, |
@@ -99,7 +101,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
99 | 101 | ||
100 | return embeds | 102 | return embeds |
101 | 103 | ||
102 | def normalize(self, target: float = 0.4, lambda_: float = 1.0): | 104 | def normalize(self, target: Optional[float] = None, lambda_: float = 1.0): |
105 | if target is None: | ||
106 | target = self.decay_target | ||
103 | w = self.temp_token_embedding.weight | 107 | w = self.temp_token_embedding.weight |
104 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | 108 | pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) |
105 | w[self.temp_token_ids] = F.normalize( | 109 | w[self.temp_token_ids] = F.normalize( |
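The decay target for embedding normalization now defaults to a statistic of the pretrained embeddings rather than the fixed 0.4. A standalone sketch of that default, assuming the stock Hugging Face CLIP text model rather than the patched embeddings class from this repo:

```python
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
weight = text_encoder.text_model.embeddings.token_embedding.weight

# Median L2 norm across all pretrained token embeddings, as in __init__ above;
# normalize(target=None) now falls back to this value instead of the old 0.4.
decay_target = weight.norm(dim=-1, keepdim=True).median().item()
print(decay_target)
```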
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
deleted file mode 100644
index a7380be..0000000
--- a/models/clip/prompt.py
+++ /dev/null
@@ -1,38 +0,0 @@ | |||
1 | from typing import Union, Optional | ||
2 | |||
3 | import torch | ||
4 | |||
5 | from transformers import CLIPTokenizer, CLIPTextModel | ||
6 | |||
7 | |||
8 | class PromptProcessor(): | ||
9 | def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel): | ||
10 | self.tokenizer = tokenizer | ||
11 | self.text_encoder = text_encoder | ||
12 | |||
13 | def get_input_ids(self, prompt: Union[str, list[str]]): | ||
14 | return self.tokenizer( | ||
15 | prompt, | ||
16 | padding="do_not_pad", | ||
17 | ).input_ids | ||
18 | |||
19 | def unify_input_ids(self, input_ids: list[list[int]]): | ||
20 | return self.tokenizer.pad( | ||
21 | {"input_ids": input_ids}, | ||
22 | padding=True, | ||
23 | pad_to_multiple_of=self.tokenizer.model_max_length, | ||
24 | return_tensors="pt" | ||
25 | ) | ||
26 | |||
27 | def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None): | ||
28 | prompts = input_ids.shape[0] | ||
29 | |||
30 | input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
31 | if position_ids is not None: | ||
32 | position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
33 | if attention_mask is not None: | ||
34 | attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
35 | |||
36 | text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | ||
37 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | ||
38 | return text_embeddings | ||
diff --git a/models/clip/util.py b/models/clip/util.py
new file mode 100644
index 0000000..8de8c19
--- /dev/null
+++ b/models/clip/util.py
@@ -0,0 +1,34 @@ | |||
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | |||
5 | from transformers import CLIPTokenizer, CLIPTextModel | ||
6 | |||
7 | |||
8 | def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]): | ||
9 | return tokenizer.pad( | ||
10 | {"input_ids": input_ids}, | ||
11 | padding=True, | ||
12 | pad_to_multiple_of=tokenizer.model_max_length, | ||
13 | return_tensors="pt" | ||
14 | ) | ||
15 | |||
16 | |||
17 | def get_extended_embeddings( | ||
18 | text_encoder: CLIPTextModel, | ||
19 | input_ids: torch.LongTensor, | ||
20 | position_ids: Optional[torch.LongTensor] = None, | ||
21 | attention_mask=None | ||
22 | ): | ||
23 | model_max_length = text_encoder.config.max_position_embeddings | ||
24 | prompts = input_ids.shape[0] | ||
25 | |||
26 | input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) | ||
27 | if position_ids is not None: | ||
28 | position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) | ||
29 | if attention_mask is not None: | ||
30 | attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) | ||
31 | |||
32 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | ||
33 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | ||
34 | return text_embeddings | ||
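The pair of free functions replaces PromptProcessor without holding references to the tokenizer or encoder. A usage sketch, assuming the standard Hugging Face CLIP checkpoints; the long prompt only exists to force a second 77-token chunk:

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

from models.clip.util import unify_input_ids, get_extended_embeddings

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a photo of a cat", "a highly detailed photo of a cat " * 20]
input_ids = tokenizer(prompts, padding="do_not_pad").input_ids

# Pad each prompt to a multiple of model_max_length (77 for CLIP) instead of
# truncating it, so long prompts become several 77-token chunks.
inputs = unify_input_ids(tokenizer, input_ids)

# Encode chunk by chunk, then fold the chunks back into one sequence per prompt.
with torch.no_grad():
    embeddings = get_extended_embeddings(text_encoder, inputs.input_ids)

print(inputs.input_ids.shape)   # e.g. (2, 154): two 77-token chunks per prompt
print(embeddings.shape)         # e.g. (2, 154, 768)
```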
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 6bc40e9..a5cfc60 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -22,7 +22,7 @@ from diffusers import ( | |||
22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
23 | from diffusers.utils import logging, randn_tensor | 23 | from diffusers.utils import logging, randn_tensor |
24 | from transformers import CLIPTextModel, CLIPTokenizer | 24 | from transformers import CLIPTextModel, CLIPTokenizer |
25 | from models.clip.prompt import PromptProcessor | 25 | from models.clip.util import unify_input_ids, get_extended_embeddings |
26 | 26 | ||
27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
28 | 28 | ||
@@ -70,8 +70,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
70 | new_config["steps_offset"] = 1 | 70 | new_config["steps_offset"] = 1 |
71 | scheduler._internal_dict = FrozenDict(new_config) | 71 | scheduler._internal_dict = FrozenDict(new_config) |
72 | 72 | ||
73 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
74 | |||
75 | self.register_modules( | 73 | self.register_modules( |
76 | vae=vae, | 74 | vae=vae, |
77 | text_encoder=text_encoder, | 75 | text_encoder=text_encoder, |
@@ -213,16 +211,22 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
213 | do_classifier_free_guidance: bool, | 211 | do_classifier_free_guidance: bool, |
214 | device | 212 | device |
215 | ): | 213 | ): |
216 | text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt | 214 | if isinstance(prompt[0], str): |
215 | text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids | ||
216 | else: | ||
217 | text_input_ids = prompt | ||
218 | |||
217 | text_input_ids *= num_images_per_prompt | 219 | text_input_ids *= num_images_per_prompt |
218 | 220 | ||
219 | if do_classifier_free_guidance: | 221 | if do_classifier_free_guidance: |
220 | unconditional_input_ids = self.prompt_processor.get_input_ids( | 222 | if isinstance(prompt[0], str): |
221 | negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt | 223 | unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids |
224 | else: | ||
225 | unconditional_input_ids = negative_prompt | ||
222 | unconditional_input_ids *= num_images_per_prompt | 226 | unconditional_input_ids *= num_images_per_prompt |
223 | text_input_ids = unconditional_input_ids + text_input_ids | 227 | text_input_ids = unconditional_input_ids + text_input_ids |
224 | 228 | ||
225 | text_inputs = self.prompt_processor.unify_input_ids(text_input_ids) | 229 | text_inputs = unify_input_ids(self.tokenizer, text_input_ids) |
226 | text_input_ids = text_inputs.input_ids | 230 | text_input_ids = text_inputs.input_ids |
227 | 231 | ||
228 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | 232 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: |
@@ -230,7 +234,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
230 | else: | 234 | else: |
231 | attention_mask = None | 235 | attention_mask = None |
232 | 236 | ||
233 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask) | 237 | text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) |
234 | 238 | ||
235 | return text_embeddings | 239 | return text_embeddings |
236 | 240 | ||
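Read together, the hunks above amount to the following flow. This is a condensed paraphrase for readability, not the literal pipeline method; attention-mask handling and device placement are omitted:

```python
from models.clip.util import unify_input_ids, get_extended_embeddings

def encode_prompt_sketch(tokenizer, text_encoder, prompt, negative_prompt,
                         num_images_per_prompt, do_classifier_free_guidance):
    # Accept either raw strings or pre-tokenized id lists.
    if isinstance(prompt[0], str):
        text_input_ids = tokenizer(prompt, padding="do_not_pad").input_ids
    else:
        text_input_ids = prompt
    text_input_ids *= num_images_per_prompt  # repeat the list of id lists

    if do_classifier_free_guidance:
        if isinstance(negative_prompt[0], str):
            unconditional_input_ids = tokenizer(negative_prompt, padding="do_not_pad").input_ids
        else:
            unconditional_input_ids = negative_prompt
        unconditional_input_ids *= num_images_per_prompt
        text_input_ids = unconditional_input_ids + text_input_ids

    text_inputs = unify_input_ids(tokenizer, text_input_ids)
    return get_extended_embeddings(text_encoder, text_inputs.input_ids)
```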
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0fe590f..fbbe6c2 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -27,7 +27,6 @@ from training.common import loss_step, generate_class_images, get_scheduler | |||
27 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
28 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args |
29 | from models.clip.embeddings import patch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
30 | from models.clip.prompt import PromptProcessor | ||
31 | from models.clip.tokenizer import MultiCLIPTokenizer | 30 | from models.clip.tokenizer import MultiCLIPTokenizer |
32 | 31 | ||
33 | logger = get_logger(__name__) | 32 | logger = get_logger(__name__) |
@@ -690,8 +689,6 @@ def main(): | |||
690 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | 689 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
691 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | 690 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) |
692 | 691 | ||
693 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
694 | |||
695 | if args.scale_lr: | 692 | if args.scale_lr: |
696 | args.learning_rate = ( | 693 | args.learning_rate = ( |
697 | args.learning_rate * args.gradient_accumulation_steps * | 694 | args.learning_rate * args.gradient_accumulation_steps * |
@@ -751,7 +748,7 @@ def main(): | |||
751 | datamodule = VlpnDataModule( | 748 | datamodule = VlpnDataModule( |
752 | data_file=args.train_data_file, | 749 | data_file=args.train_data_file, |
753 | batch_size=args.train_batch_size, | 750 | batch_size=args.train_batch_size, |
754 | prompt_processor=prompt_processor, | 751 | tokenizer=tokenizer, |
755 | class_subdir=args.class_image_dir, | 752 | class_subdir=args.class_image_dir, |
756 | num_class_images=args.num_class_images, | 753 | num_class_images=args.num_class_images, |
757 | size=args.resolution, | 754 | size=args.resolution, |
@@ -876,7 +873,7 @@ def main(): | |||
876 | vae, | 873 | vae, |
877 | noise_scheduler, | 874 | noise_scheduler, |
878 | unet, | 875 | unet, |
879 | prompt_processor, | 876 | text_encoder, |
880 | args.num_class_images, | 877 | args.num_class_images, |
881 | args.prior_loss_weight, | 878 | args.prior_loss_weight, |
882 | args.seed, | 879 | args.seed, |
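The substantive change to this script is what the loss_step partial closes over. A sketch of the binding via a hypothetical wrapper; make_loss_fn is not part of the commit:

```python
from functools import partial

from training.common import loss_step

def make_loss_fn(vae, noise_scheduler, unet, text_encoder,
                 num_class_images, prior_loss_weight, seed):
    # text_encoder now stands in for the removed PromptProcessor.
    return partial(loss_step, vae, noise_scheduler, unet, text_encoder,
                   num_class_images, prior_loss_weight, seed)

# loop = make_loss_fn(...)
# loss, acc, bsz = loop(step, batch)        # training
# loss, acc, bsz = loop(step, batch, True)  # evaluation
```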
diff --git a/train_ti.py b/train_ti.py
index e18ee38..8c86586 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,11 +21,10 @@ from slugify import slugify | |||
21 | from util import load_config, load_embeddings_from_dir | 21 | from util import load_config, load_embeddings_from_dir |
22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
23 | from data.csv import VlpnDataModule, VlpnDataItem | 23 | from data.csv import VlpnDataModule, VlpnDataItem |
24 | from training.common import loss_step, generate_class_images, get_scheduler | 24 | from training.common import loss_step, train_loop, generate_class_images, get_scheduler |
25 | from training.lr import LRFinder | 25 | from training.lr import LRFinder |
26 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args | 26 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args |
27 | from models.clip.embeddings import patch_managed_embeddings | 27 | from models.clip.embeddings import patch_managed_embeddings |
28 | from models.clip.prompt import PromptProcessor | ||
29 | from models.clip.tokenizer import MultiCLIPTokenizer | 28 | from models.clip.tokenizer import MultiCLIPTokenizer |
30 | 29 | ||
31 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
@@ -198,12 +197,6 @@ def parse_args(): | |||
198 | default=100 | 197 | default=100 |
199 | ) | 198 | ) |
200 | parser.add_argument( | 199 | parser.add_argument( |
201 | "--max_train_steps", | ||
202 | type=int, | ||
203 | default=None, | ||
204 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | ||
205 | ) | ||
206 | parser.add_argument( | ||
207 | "--gradient_accumulation_steps", | 200 | "--gradient_accumulation_steps", |
208 | type=int, | 201 | type=int, |
209 | default=1, | 202 | default=1, |
@@ -409,7 +402,7 @@ def parse_args(): | |||
409 | ) | 402 | ) |
410 | parser.add_argument( | 403 | parser.add_argument( |
411 | "--decay_target", | 404 | "--decay_target", |
412 | default=0.4, | 405 | default=None, |
413 | type=float, | 406 | type=float, |
414 | help="Embedding decay target." | 407 | help="Embedding decay target." |
415 | ) | 408 | ) |
@@ -668,8 +661,6 @@ def main(): | |||
668 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | 661 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
669 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | 662 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) |
670 | 663 | ||
671 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
672 | |||
673 | if args.scale_lr: | 664 | if args.scale_lr: |
674 | args.learning_rate = ( | 665 | args.learning_rate = ( |
675 | args.learning_rate * args.gradient_accumulation_steps * | 666 | args.learning_rate * args.gradient_accumulation_steps * |
@@ -722,7 +713,7 @@ def main(): | |||
722 | datamodule = VlpnDataModule( | 713 | datamodule = VlpnDataModule( |
723 | data_file=args.train_data_file, | 714 | data_file=args.train_data_file, |
724 | batch_size=args.train_batch_size, | 715 | batch_size=args.train_batch_size, |
725 | prompt_processor=prompt_processor, | 716 | tokenizer=tokenizer, |
726 | class_subdir=args.class_image_dir, | 717 | class_subdir=args.class_image_dir, |
727 | num_class_images=args.num_class_images, | 718 | num_class_images=args.num_class_images, |
728 | size=args.resolution, | 719 | size=args.resolution, |
@@ -759,13 +750,7 @@ def main(): | |||
759 | args.sample_steps | 750 | args.sample_steps |
760 | ) | 751 | ) |
761 | 752 | ||
762 | # Scheduler and math around the number of training steps. | ||
763 | overrode_max_train_steps = False | ||
764 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 753 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
765 | if args.max_train_steps is None: | ||
766 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
767 | overrode_max_train_steps = True | ||
768 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
769 | 754 | ||
770 | if args.find_lr: | 755 | if args.find_lr: |
771 | lr_scheduler = None | 756 | lr_scheduler = None |
@@ -781,7 +766,7 @@ def main(): | |||
781 | annealing_exp=args.lr_annealing_exp, | 766 | annealing_exp=args.lr_annealing_exp, |
782 | cycles=args.lr_cycles, | 767 | cycles=args.lr_cycles, |
783 | warmup_epochs=args.lr_warmup_epochs, | 768 | warmup_epochs=args.lr_warmup_epochs, |
784 | max_train_steps=args.max_train_steps, | 769 | num_train_epochs=args.num_train_epochs, |
785 | num_update_steps_per_epoch=num_update_steps_per_epoch, | 770 | num_update_steps_per_epoch=num_update_steps_per_epoch, |
786 | gradient_accumulation_steps=args.gradient_accumulation_steps | 771 | gradient_accumulation_steps=args.gradient_accumulation_steps |
787 | ) | 772 | ) |
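With --max_train_steps removed, the total step count is derived from --num_train_epochs inside get_scheduler. A worked example of the bookkeeping with made-up numbers:

```python
import math

num_train_epochs = 100
len_train_dataloader = 250
gradient_accumulation_steps = 4

num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 63
num_train_steps = num_train_epochs * num_update_steps_per_epoch                             # 6300

# get_scheduler() sizes the LR schedule as num_train_steps * gradient_accumulation_steps,
# which is what the removed max_train_steps bookkeeping used to provide.
print(num_update_steps_per_epoch, num_train_steps)
```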
@@ -805,15 +790,6 @@ def main(): | |||
805 | else: | 790 | else: |
806 | unet.eval() | 791 | unet.eval() |
807 | 792 | ||
808 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | ||
809 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | ||
810 | if overrode_max_train_steps: | ||
811 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
812 | |||
813 | num_val_steps_per_epoch = len(val_dataloader) | ||
814 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
815 | val_steps = num_val_steps_per_epoch * num_epochs | ||
816 | |||
817 | @contextmanager | 793 | @contextmanager |
818 | def on_train(): | 794 | def on_train(): |
819 | try: | 795 | try: |
@@ -842,19 +818,44 @@ def main(): | |||
842 | min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) | 818 | min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) |
843 | ) | 819 | ) |
844 | 820 | ||
821 | if args.use_ema: | ||
822 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
823 | |||
824 | def on_log(): | ||
825 | if args.use_ema: | ||
826 | return {"ema_decay": ema_embeddings.decay} | ||
827 | return {} | ||
828 | |||
845 | loop = partial( | 829 | loop = partial( |
846 | loss_step, | 830 | loss_step, |
847 | vae, | 831 | vae, |
848 | noise_scheduler, | 832 | noise_scheduler, |
849 | unet, | 833 | unet, |
850 | prompt_processor, | 834 | text_encoder, |
851 | args.num_class_images, | 835 | args.num_class_images, |
852 | args.prior_loss_weight, | 836 | args.prior_loss_weight, |
853 | args.seed, | 837 | args.seed, |
854 | ) | 838 | ) |
855 | 839 | ||
856 | # We need to initialize the trackers we use, and also store our configuration. | 840 | checkpointer = Checkpointer( |
857 | # The trackers initializes automatically on the main process. | 841 | weight_dtype=weight_dtype, |
842 | datamodule=datamodule, | ||
843 | accelerator=accelerator, | ||
844 | vae=vae, | ||
845 | unet=unet, | ||
846 | tokenizer=tokenizer, | ||
847 | text_encoder=text_encoder, | ||
848 | ema_embeddings=ema_embeddings, | ||
849 | scheduler=checkpoint_scheduler, | ||
850 | placeholder_token=args.placeholder_token, | ||
851 | new_ids=new_ids, | ||
852 | output_dir=basepath, | ||
853 | sample_image_size=args.sample_image_size, | ||
854 | sample_batch_size=args.sample_batch_size, | ||
855 | sample_batches=args.sample_batches, | ||
856 | seed=args.seed | ||
857 | ) | ||
858 | |||
858 | if accelerator.is_main_process: | 859 | if accelerator.is_main_process: |
859 | config = vars(args).copy() | 860 | config = vars(args).copy() |
860 | config["initializer_token"] = " ".join(config["initializer_token"]) | 861 | config["initializer_token"] = " ".join(config["initializer_token"]) |
@@ -882,190 +883,27 @@ def main(): | |||
882 | 883 | ||
883 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 884 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
884 | plt.close() | 885 | plt.close() |
885 | 886 | else: | |
886 | quit() | 887 | train_loop( |
887 | 888 | accelerator=accelerator, | |
888 | # Train! | 889 | optimizer=optimizer, |
889 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | 890 | lr_scheduler=lr_scheduler, |
890 | 891 | model=text_encoder, | |
891 | logger.info("***** Running training *****") | 892 | checkpointer=checkpointer, |
892 | logger.info(f" Num Epochs = {num_epochs}") | 893 | train_dataloader=train_dataloader, |
893 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") | 894 | val_dataloader=val_dataloader, |
894 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") | 895 | loss_step=loop, |
895 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | 896 | sample_frequency=args.sample_frequency, |
896 | logger.info(f" Total optimization steps = {args.max_train_steps}") | 897 | sample_steps=args.sample_steps, |
897 | # Only show the progress bar once on each machine. | 898 | checkpoint_frequency=args.checkpoint_frequency, |
898 | 899 | global_step_offset=global_step_offset, | |
899 | global_step = 0 | 900 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
900 | 901 | num_epochs=args.num_train_epochs, | |
901 | avg_loss = AverageMeter() | 902 | on_log=on_log, |
902 | avg_acc = AverageMeter() | 903 | on_train=on_train, |
903 | 904 | on_after_optimize=on_after_optimize, | |
904 | avg_loss_val = AverageMeter() | 905 | on_eval=on_eval |
905 | avg_acc_val = AverageMeter() | 906 | ) |
906 | |||
907 | max_acc_val = 0.0 | ||
908 | |||
909 | checkpointer = Checkpointer( | ||
910 | weight_dtype=weight_dtype, | ||
911 | datamodule=datamodule, | ||
912 | accelerator=accelerator, | ||
913 | vae=vae, | ||
914 | unet=unet, | ||
915 | tokenizer=tokenizer, | ||
916 | text_encoder=text_encoder, | ||
917 | ema_embeddings=ema_embeddings, | ||
918 | scheduler=checkpoint_scheduler, | ||
919 | placeholder_token=args.placeholder_token, | ||
920 | new_ids=new_ids, | ||
921 | output_dir=basepath, | ||
922 | sample_image_size=args.sample_image_size, | ||
923 | sample_batch_size=args.sample_batch_size, | ||
924 | sample_batches=args.sample_batches, | ||
925 | seed=args.seed | ||
926 | ) | ||
927 | |||
928 | local_progress_bar = tqdm( | ||
929 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | ||
930 | disable=not accelerator.is_local_main_process, | ||
931 | dynamic_ncols=True | ||
932 | ) | ||
933 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") | ||
934 | |||
935 | global_progress_bar = tqdm( | ||
936 | range(args.max_train_steps + val_steps), | ||
937 | disable=not accelerator.is_local_main_process, | ||
938 | dynamic_ncols=True | ||
939 | ) | ||
940 | global_progress_bar.set_description("Total progress") | ||
941 | |||
942 | try: | ||
943 | for epoch in range(num_epochs): | ||
944 | if accelerator.is_main_process: | ||
945 | if epoch % args.sample_frequency == 0: | ||
946 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) | ||
947 | |||
948 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | ||
949 | local_progress_bar.reset() | ||
950 | |||
951 | text_encoder.train() | ||
952 | |||
953 | with on_train(): | ||
954 | for step, batch in enumerate(train_dataloader): | ||
955 | with accelerator.accumulate(text_encoder): | ||
956 | loss, acc, bsz = loop(step, batch) | ||
957 | |||
958 | accelerator.backward(loss) | ||
959 | |||
960 | optimizer.step() | ||
961 | lr_scheduler.step() | ||
962 | optimizer.zero_grad(set_to_none=True) | ||
963 | |||
964 | avg_loss.update(loss.detach_(), bsz) | ||
965 | avg_acc.update(acc.detach_(), bsz) | ||
966 | |||
967 | # Checks if the accelerator has performed an optimization step behind the scenes | ||
968 | if accelerator.sync_gradients: | ||
969 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | ||
970 | |||
971 | if args.use_ema: | ||
972 | ema_embeddings.step( | ||
973 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
974 | |||
975 | local_progress_bar.update(1) | ||
976 | global_progress_bar.update(1) | ||
977 | |||
978 | global_step += 1 | ||
979 | |||
980 | logs = { | ||
981 | "train/loss": avg_loss.avg.item(), | ||
982 | "train/acc": avg_acc.avg.item(), | ||
983 | "train/cur_loss": loss.item(), | ||
984 | "train/cur_acc": acc.item(), | ||
985 | "lr": lr_scheduler.get_last_lr()[0], | ||
986 | } | ||
987 | if args.use_ema: | ||
988 | logs["ema_decay"] = ema_embeddings.decay | ||
989 | |||
990 | accelerator.log(logs, step=global_step) | ||
991 | |||
992 | local_progress_bar.set_postfix(**logs) | ||
993 | |||
994 | if global_step >= args.max_train_steps: | ||
995 | break | ||
996 | |||
997 | accelerator.wait_for_everyone() | ||
998 | |||
999 | text_encoder.eval() | ||
1000 | |||
1001 | cur_loss_val = AverageMeter() | ||
1002 | cur_acc_val = AverageMeter() | ||
1003 | |||
1004 | with torch.inference_mode(): | ||
1005 | with on_eval(): | ||
1006 | for step, batch in enumerate(val_dataloader): | ||
1007 | loss, acc, bsz = loop(step, batch, True) | ||
1008 | |||
1009 | loss = loss.detach_() | ||
1010 | acc = acc.detach_() | ||
1011 | |||
1012 | cur_loss_val.update(loss, bsz) | ||
1013 | cur_acc_val.update(acc, bsz) | ||
1014 | |||
1015 | avg_loss_val.update(loss, bsz) | ||
1016 | avg_acc_val.update(acc, bsz) | ||
1017 | |||
1018 | local_progress_bar.update(1) | ||
1019 | global_progress_bar.update(1) | ||
1020 | |||
1021 | logs = { | ||
1022 | "val/loss": avg_loss_val.avg.item(), | ||
1023 | "val/acc": avg_acc_val.avg.item(), | ||
1024 | "val/cur_loss": loss.item(), | ||
1025 | "val/cur_acc": acc.item(), | ||
1026 | } | ||
1027 | local_progress_bar.set_postfix(**logs) | ||
1028 | |||
1029 | logs["val/cur_loss"] = cur_loss_val.avg.item() | ||
1030 | logs["val/cur_acc"] = cur_acc_val.avg.item() | ||
1031 | |||
1032 | accelerator.log(logs, step=global_step) | ||
1033 | |||
1034 | local_progress_bar.clear() | ||
1035 | global_progress_bar.clear() | ||
1036 | |||
1037 | if accelerator.is_main_process: | ||
1038 | if avg_acc_val.avg.item() > max_acc_val: | ||
1039 | accelerator.print( | ||
1040 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | ||
1041 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | ||
1042 | max_acc_val = avg_acc_val.avg.item() | ||
1043 | |||
1044 | if (epoch + 1) % args.checkpoint_frequency == 0: | ||
1045 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
1046 | save_args(basepath, args, { | ||
1047 | "global_step": global_step + global_step_offset | ||
1048 | }) | ||
1049 | |||
1050 | # Create the pipeline using using the trained modules and save it. | ||
1051 | if accelerator.is_main_process: | ||
1052 | print("Finished! Saving final checkpoint and resume state.") | ||
1053 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
1054 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) | ||
1055 | save_args(basepath, args, { | ||
1056 | "global_step": global_step + global_step_offset | ||
1057 | }) | ||
1058 | accelerator.end_training() | ||
1059 | |||
1060 | except KeyboardInterrupt: | ||
1061 | if accelerator.is_main_process: | ||
1062 | print("Interrupted, saving checkpoint and resume state...") | ||
1063 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
1064 | save_args(basepath, args, { | ||
1065 | "global_step": global_step + global_step_offset | ||
1066 | }) | ||
1067 | accelerator.end_training() | ||
1068 | quit() | ||
1069 | 907 | ||
1070 | 908 | ||
1071 | if __name__ == "__main__": | 909 | if __name__ == "__main__": |
diff --git a/training/common.py b/training/common.py
index 90cf910..842ac07 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,14 +1,30 @@ | |||
1 | import math | 1 | import math |
2 | from contextlib import _GeneratorContextManager, nullcontext | ||
3 | from typing import Callable, Any, Tuple, Union | ||
2 | 4 | ||
3 | import torch | 5 | import torch |
4 | import torch.nn.functional as F | 6 | import torch.nn.functional as F |
7 | from torch.utils.data import DataLoader | ||
5 | 8 | ||
9 | from accelerate import Accelerator | ||
10 | from transformers import CLIPTokenizer, CLIPTextModel | ||
6 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 11 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
7 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 12 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup |
8 | 13 | ||
9 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 14 | from tqdm.auto import tqdm |
10 | 15 | ||
16 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
17 | from models.clip.util import get_extended_embeddings | ||
11 | from training.optimization import get_one_cycle_schedule | 18 | from training.optimization import get_one_cycle_schedule |
19 | from training.util import AverageMeter, CheckpointerBase | ||
20 | |||
21 | |||
22 | def noop(*args, **kwards): | ||
23 | pass | ||
24 | |||
25 | |||
26 | def noop_on_log(): | ||
27 | return {} | ||
12 | 28 | ||
13 | 29 | ||
14 | def get_scheduler( | 30 | def get_scheduler( |
@@ -22,10 +38,11 @@ def get_scheduler( | |||
22 | cycles: int, | 38 | cycles: int, |
23 | warmup_epochs: int, | 39 | warmup_epochs: int, |
24 | optimizer: torch.optim.Optimizer, | 40 | optimizer: torch.optim.Optimizer, |
25 | max_train_steps: int, | 41 | num_train_epochs: int, |
26 | num_update_steps_per_epoch: int, | 42 | num_update_steps_per_epoch: int, |
27 | gradient_accumulation_steps: int, | 43 | gradient_accumulation_steps: int, |
28 | ): | 44 | ): |
45 | num_train_steps = num_train_epochs * num_update_steps_per_epoch | ||
29 | warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps | 46 | warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps |
30 | 47 | ||
31 | if id == "one_cycle": | 48 | if id == "one_cycle": |
@@ -33,7 +50,7 @@ def get_scheduler( | |||
33 | 50 | ||
34 | lr_scheduler = get_one_cycle_schedule( | 51 | lr_scheduler = get_one_cycle_schedule( |
35 | optimizer=optimizer, | 52 | optimizer=optimizer, |
36 | num_training_steps=max_train_steps * gradient_accumulation_steps, | 53 | num_training_steps=num_train_steps * gradient_accumulation_steps, |
37 | warmup=warmup_func, | 54 | warmup=warmup_func, |
38 | annealing=annealing_func, | 55 | annealing=annealing_func, |
39 | warmup_exp=warmup_exp, | 56 | warmup_exp=warmup_exp, |
@@ -42,12 +59,12 @@ def get_scheduler( | |||
42 | ) | 59 | ) |
43 | elif id == "cosine_with_restarts": | 60 | elif id == "cosine_with_restarts": |
44 | cycles = cycles if cycles is not None else math.ceil( | 61 | cycles = cycles if cycles is not None else math.ceil( |
45 | math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch))) | 62 | math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch))) |
46 | 63 | ||
47 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 64 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
48 | optimizer=optimizer, | 65 | optimizer=optimizer, |
49 | num_warmup_steps=warmup_steps, | 66 | num_warmup_steps=warmup_steps, |
50 | num_training_steps=max_train_steps * gradient_accumulation_steps, | 67 | num_training_steps=num_train_steps * gradient_accumulation_steps, |
51 | num_cycles=cycles, | 68 | num_cycles=cycles, |
52 | ) | 69 | ) |
53 | else: | 70 | else: |
@@ -55,7 +72,7 @@ def get_scheduler( | |||
55 | id, | 72 | id, |
56 | optimizer=optimizer, | 73 | optimizer=optimizer, |
57 | num_warmup_steps=warmup_steps, | 74 | num_warmup_steps=warmup_steps, |
58 | num_training_steps=max_train_steps * gradient_accumulation_steps, | 75 | num_training_steps=num_train_steps * gradient_accumulation_steps, |
59 | ) | 76 | ) |
60 | 77 | ||
61 | return lr_scheduler | 78 | return lr_scheduler |
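For "cosine_with_restarts", the default cycle count is derived from the same quantities. A small worked example with made-up numbers, mirroring the expression above:

```python
import math

num_train_steps = 6300                 # num_train_epochs * num_update_steps_per_epoch
num_update_steps_per_epoch = 63
warmup_steps = 10 * 63 * 4             # warmup_epochs * num_update_steps_per_epoch * grad. accum.

cycles = math.ceil(math.sqrt((num_train_steps - warmup_steps) / num_update_steps_per_epoch))
print(cycles)  # 8 with these numbers
```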
@@ -117,12 +134,12 @@ def loss_step( | |||
117 | vae: AutoencoderKL, | 134 | vae: AutoencoderKL, |
118 | noise_scheduler: DDPMScheduler, | 135 | noise_scheduler: DDPMScheduler, |
119 | unet: UNet2DConditionModel, | 136 | unet: UNet2DConditionModel, |
120 | prompt_processor, | 137 | text_encoder: CLIPTextModel, |
121 | num_class_images: int, | 138 | num_class_images: int, |
122 | prior_loss_weight: float, | 139 | prior_loss_weight: float, |
123 | seed: int, | 140 | seed: int, |
124 | step: int, | 141 | step: int, |
125 | batch, | 142 | batch: dict[str, Any], |
126 | eval: bool = False | 143 | eval: bool = False |
127 | ): | 144 | ): |
128 | # Convert images to latent space | 145 | # Convert images to latent space |
@@ -149,7 +166,8 @@ def loss_step( | |||
149 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 166 | noisy_latents = noisy_latents.to(dtype=unet.dtype) |
150 | 167 | ||
151 | # Get the text embedding for conditioning | 168 | # Get the text embedding for conditioning |
152 | encoder_hidden_states = prompt_processor.get_embeddings( | 169 | encoder_hidden_states = get_extended_embeddings( |
170 | text_encoder, | ||
153 | batch["input_ids"], | 171 | batch["input_ids"], |
154 | batch["attention_mask"] | 172 | batch["attention_mask"] |
155 | ) | 173 | ) |
@@ -185,3 +203,172 @@ def loss_step( | |||
185 | acc = (model_pred == target).float().mean() | 203 | acc = (model_pred == target).float().mean() |
186 | 204 | ||
187 | return loss, acc, bsz | 205 | return loss, acc, bsz |
206 | |||
207 | |||
208 | def train_loop( | ||
209 | accelerator: Accelerator, | ||
210 | optimizer: torch.optim.Optimizer, | ||
211 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
212 | model: torch.nn.Module, | ||
213 | checkpointer: CheckpointerBase, | ||
214 | train_dataloader: DataLoader, | ||
215 | val_dataloader: DataLoader, | ||
216 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | ||
217 | sample_frequency: int = 10, | ||
218 | sample_steps: int = 20, | ||
219 | checkpoint_frequency: int = 50, | ||
220 | global_step_offset: int = 0, | ||
221 | gradient_accumulation_steps: int = 1, | ||
222 | num_epochs: int = 100, | ||
223 | on_log: Callable[[], dict[str, Any]] = noop_on_log, | ||
224 | on_train: Callable[[], _GeneratorContextManager] = nullcontext, | ||
225 | on_before_optimize: Callable[[], None] = noop, | ||
226 | on_after_optimize: Callable[[float], None] = noop, | ||
227 | on_eval: Callable[[], _GeneratorContextManager] = nullcontext | ||
228 | ): | ||
229 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | ||
230 | num_train_steps = num_epochs * num_update_steps_per_epoch | ||
231 | |||
232 | num_val_steps_per_epoch = len(val_dataloader) | ||
233 | num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch) | ||
234 | num_val_steps = num_val_steps_per_epoch * num_epochs | ||
235 | |||
236 | global_step = 0 | ||
237 | |||
238 | avg_loss = AverageMeter() | ||
239 | avg_acc = AverageMeter() | ||
240 | |||
241 | avg_loss_val = AverageMeter() | ||
242 | avg_acc_val = AverageMeter() | ||
243 | |||
244 | max_acc_val = 0.0 | ||
245 | |||
246 | local_progress_bar = tqdm( | ||
247 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | ||
248 | disable=not accelerator.is_local_main_process, | ||
249 | dynamic_ncols=True | ||
250 | ) | ||
251 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") | ||
252 | |||
253 | global_progress_bar = tqdm( | ||
254 | range(num_train_steps + num_val_steps), | ||
255 | disable=not accelerator.is_local_main_process, | ||
256 | dynamic_ncols=True | ||
257 | ) | ||
258 | global_progress_bar.set_description("Total progress") | ||
259 | |||
260 | try: | ||
261 | for epoch in range(num_epochs): | ||
262 | if accelerator.is_main_process: | ||
263 | if epoch % sample_frequency == 0: | ||
264 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | ||
265 | |||
266 | if epoch % checkpoint_frequency == 0 and epoch != 0: | ||
267 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
268 | |||
269 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | ||
270 | local_progress_bar.reset() | ||
271 | |||
272 | model.train() | ||
273 | |||
274 | with on_train(): | ||
275 | for step, batch in enumerate(train_dataloader): | ||
276 | with accelerator.accumulate(model): | ||
277 | loss, acc, bsz = loss_step(step, batch) | ||
278 | |||
279 | accelerator.backward(loss) | ||
280 | |||
281 | on_before_optimize() | ||
282 | |||
283 | optimizer.step() | ||
284 | lr_scheduler.step() | ||
285 | optimizer.zero_grad(set_to_none=True) | ||
286 | |||
287 | avg_loss.update(loss.detach_(), bsz) | ||
288 | avg_acc.update(acc.detach_(), bsz) | ||
289 | |||
290 | # Checks if the accelerator has performed an optimization step behind the scenes | ||
291 | if accelerator.sync_gradients: | ||
292 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | ||
293 | |||
294 | local_progress_bar.update(1) | ||
295 | global_progress_bar.update(1) | ||
296 | |||
297 | global_step += 1 | ||
298 | |||
299 | logs = { | ||
300 | "train/loss": avg_loss.avg.item(), | ||
301 | "train/acc": avg_acc.avg.item(), | ||
302 | "train/cur_loss": loss.item(), | ||
303 | "train/cur_acc": acc.item(), | ||
304 | "lr": lr_scheduler.get_last_lr()[0], | ||
305 | } | ||
306 | logs.update(on_log()) | ||
307 | |||
308 | accelerator.log(logs, step=global_step) | ||
309 | |||
310 | local_progress_bar.set_postfix(**logs) | ||
311 | |||
312 | if global_step >= num_train_steps: | ||
313 | break | ||
314 | |||
315 | accelerator.wait_for_everyone() | ||
316 | |||
317 | model.eval() | ||
318 | |||
319 | cur_loss_val = AverageMeter() | ||
320 | cur_acc_val = AverageMeter() | ||
321 | |||
322 | with torch.inference_mode(): | ||
323 | with on_eval(): | ||
324 | for step, batch in enumerate(val_dataloader): | ||
325 | loss, acc, bsz = loss_step(step, batch, True) | ||
326 | |||
327 | loss = loss.detach_() | ||
328 | acc = acc.detach_() | ||
329 | |||
330 | cur_loss_val.update(loss, bsz) | ||
331 | cur_acc_val.update(acc, bsz) | ||
332 | |||
333 | avg_loss_val.update(loss, bsz) | ||
334 | avg_acc_val.update(acc, bsz) | ||
335 | |||
336 | local_progress_bar.update(1) | ||
337 | global_progress_bar.update(1) | ||
338 | |||
339 | logs = { | ||
340 | "val/loss": avg_loss_val.avg.item(), | ||
341 | "val/acc": avg_acc_val.avg.item(), | ||
342 | "val/cur_loss": loss.item(), | ||
343 | "val/cur_acc": acc.item(), | ||
344 | } | ||
345 | local_progress_bar.set_postfix(**logs) | ||
346 | |||
347 | logs["val/cur_loss"] = cur_loss_val.avg.item() | ||
348 | logs["val/cur_acc"] = cur_acc_val.avg.item() | ||
349 | |||
350 | accelerator.log(logs, step=global_step) | ||
351 | |||
352 | local_progress_bar.clear() | ||
353 | global_progress_bar.clear() | ||
354 | |||
355 | if accelerator.is_main_process: | ||
356 | if avg_acc_val.avg.item() > max_acc_val: | ||
357 | accelerator.print( | ||
358 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | ||
359 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | ||
360 | max_acc_val = avg_acc_val.avg.item() | ||
361 | |||
362 | # Create the pipeline using using the trained modules and save it. | ||
363 | if accelerator.is_main_process: | ||
364 | print("Finished!") | ||
365 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
366 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | ||
367 | accelerator.end_training() | ||
368 | |||
369 | except KeyboardInterrupt: | ||
370 | if accelerator.is_main_process: | ||
371 | print("Interrupted") | ||
372 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
373 | accelerator.end_training() | ||
374 | quit() | ||
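train_loop now owns the progress bars, metric averaging, milestone checkpoints, and the KeyboardInterrupt handling. A hypothetical caller, condensed from the train_ti.py changes above; every object here is a placeholder that the real script builds:

```python
from functools import partial

from training.common import loss_step, train_loop

def run(accelerator, optimizer, lr_scheduler, text_encoder, checkpointer,
        train_dataloader, val_dataloader, vae, noise_scheduler, unet, args):
    loop = partial(loss_step, vae, noise_scheduler, unet, text_encoder,
                   args.num_class_images, args.prior_loss_weight, args.seed)

    train_loop(
        accelerator=accelerator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        model=text_encoder,
        checkpointer=checkpointer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_step=loop,
        num_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        # on_log / on_train / on_before_optimize / on_after_optimize / on_eval
        # default to no-ops or nullcontext when omitted.
    )
```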
diff --git a/training/util.py b/training/util.py
index 60d64f0..0ec2032 100644
--- a/training/util.py
+++ b/training/util.py
@@ -55,8 +55,19 @@ class CheckpointerBase: | |||
55 | self.sample_batches = sample_batches | 55 | self.sample_batches = sample_batches |
56 | self.sample_batch_size = sample_batch_size | 56 | self.sample_batch_size = sample_batch_size |
57 | 57 | ||
58 | @torch.no_grad() | ||
59 | def checkpoint(self, step: int, postfix: str): | ||
60 | pass | ||
61 | |||
58 | @torch.inference_mode() | 62 | @torch.inference_mode() |
59 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 63 | def save_samples( |
64 | self, | ||
65 | pipeline, | ||
66 | step: int, | ||
67 | num_inference_steps: int, | ||
68 | guidance_scale: float = 7.5, | ||
69 | eta: float = 0.0 | ||
70 | ): | ||
60 | samples_path = Path(self.output_dir).joinpath("samples") | 71 | samples_path = Path(self.output_dir).joinpath("samples") |
61 | 72 | ||
62 | train_data = self.datamodule.train_dataloader | 73 | train_data = self.datamodule.train_dataloader |
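The base class now ships a no-op checkpoint hook and a typed save_samples signature, so train_loop can call both without knowing the concrete checkpointer. A stub subclass, purely illustrative; the real Checkpointer classes in the training scripts build a pipeline and save the learned weights here:

```python
import torch

from training.util import CheckpointerBase

class NullCheckpointer(CheckpointerBase):
    @torch.no_grad()
    def checkpoint(self, step: int, postfix: str):
        # Real implementations persist the learned embeddings / model weights.
        print(f"checkpoint at step {step} ({postfix})")

    @torch.inference_mode()
    def save_samples(self, step: int, num_inference_steps: int):
        # The training scripts construct a diffusion pipeline here and delegate
        # to CheckpointerBase.save_samples(pipeline, step, num_inference_steps).
        print(f"samples at step {step}")
```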