-rw-r--r--   train_dreambooth.py      | 126
-rw-r--r--   train_ti.py              | 175
-rw-r--r--   training/optimization.py |   2
-rw-r--r--   training/util.py         | 131
4 files changed, 203 insertions(+), 231 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0f8fece..9749c62 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed
16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel | 16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
17 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 17 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
18 | from diffusers.training_utils import EMAModel | 18 | from diffusers.training_utils import EMAModel |
19 | from PIL import Image | ||
20 | from tqdm.auto import tqdm | 19 | from tqdm.auto import tqdm |
21 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
22 | from slugify import slugify | 21 | from slugify import slugify |
@@ -25,6 +24,7 @@ from common import load_text_embeddings
25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
26 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
27 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args | ||
28 | from models.clip.prompt import PromptProcessor | 28 | from models.clip.prompt import PromptProcessor |
29 | 29 | ||
30 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
@@ -385,41 +385,7 @@ def parse_args():
385 | return args | 385 | return args |
386 | 386 | ||
387 | 387 | ||
388 | def save_args(basepath: Path, args, extra={}): | 388 | class Checkpointer(CheckpointerBase): |
389 | info = {"args": vars(args)} | ||
390 | info["args"].update(extra) | ||
391 | with open(basepath.joinpath("args.json"), "w") as f: | ||
392 | json.dump(info, f, indent=4) | ||
393 | |||
394 | |||
395 | def freeze_params(params): | ||
396 | for param in params: | ||
397 | param.requires_grad = False | ||
398 | |||
399 | |||
400 | def make_grid(images, rows, cols): | ||
401 | w, h = images[0].size | ||
402 | grid = Image.new('RGB', size=(cols*w, rows*h)) | ||
403 | for i, image in enumerate(images): | ||
404 | grid.paste(image, box=(i % cols*w, i//cols*h)) | ||
405 | return grid | ||
406 | |||
407 | |||
408 | class AverageMeter: | ||
409 | def __init__(self, name=None): | ||
410 | self.name = name | ||
411 | self.reset() | ||
412 | |||
413 | def reset(self): | ||
414 | self.sum = self.count = self.avg = 0 | ||
415 | |||
416 | def update(self, val, n=1): | ||
417 | self.sum += val * n | ||
418 | self.count += n | ||
419 | self.avg = self.sum / self.count | ||
420 | |||
421 | |||
422 | class Checkpointer: | ||
423 | def __init__( | 389 | def __init__( |
424 | self, | 390 | self, |
425 | datamodule, | 391 | datamodule, |
@@ -437,9 +403,20 @@ class Checkpointer:
437 | sample_image_size, | 403 | sample_image_size, |
438 | sample_batches, | 404 | sample_batches, |
439 | sample_batch_size, | 405 | sample_batch_size, |
440 | seed | 406 | seed, |
441 | ): | 407 | ): |
442 | self.datamodule = datamodule | 408 | super().__init__( |
409 | datamodule=datamodule, | ||
410 | output_dir=output_dir, | ||
411 | instance_identifier=instance_identifier, | ||
412 | placeholder_token=placeholder_token, | ||
413 | placeholder_token_id=placeholder_token_id, | ||
414 | sample_image_size=sample_image_size, | ||
415 | seed=seed or torch.random.seed(), | ||
416 | sample_batches=sample_batches, | ||
417 | sample_batch_size=sample_batch_size | ||
418 | ) | ||
419 | |||
443 | self.accelerator = accelerator | 420 | self.accelerator = accelerator |
444 | self.vae = vae | 421 | self.vae = vae |
445 | self.unet = unet | 422 | self.unet = unet |
@@ -447,14 +424,6 @@ class Checkpointer:
447 | self.tokenizer = tokenizer | 424 | self.tokenizer = tokenizer |
448 | self.text_encoder = text_encoder | 425 | self.text_encoder = text_encoder |
449 | self.scheduler = scheduler | 426 | self.scheduler = scheduler |
450 | self.output_dir = output_dir | ||
451 | self.instance_identifier = instance_identifier | ||
452 | self.placeholder_token = placeholder_token | ||
453 | self.placeholder_token_id = placeholder_token_id | ||
454 | self.sample_image_size = sample_image_size | ||
455 | self.seed = seed or torch.random.seed() | ||
456 | self.sample_batches = sample_batches | ||
457 | self.sample_batch_size = sample_batch_size | ||
458 | 427 | ||
459 | @torch.no_grad() | 428 | @torch.no_grad() |
460 | def save_model(self): | 429 | def save_model(self): |
@@ -481,8 +450,6 @@ class Checkpointer:
481 | 450 | ||
482 | @torch.no_grad() | 451 | @torch.no_grad() |
483 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 452 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
484 | samples_path = Path(self.output_dir).joinpath("samples") | ||
485 | |||
486 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) | 453 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) |
487 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 454 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
488 | 455 | ||
@@ -495,72 +462,11 @@ class Checkpointer:
495 | ).to(self.accelerator.device) | 462 | ).to(self.accelerator.device) |
496 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 463 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
497 | 464 | ||
498 | train_data = self.datamodule.train_dataloader() | 465 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) |
499 | val_data = self.datamodule.val_dataloader() | ||
500 | |||
501 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | ||
502 | stable_latents = torch.randn( | ||
503 | (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8), | ||
504 | device=pipeline.device, | ||
505 | generator=generator, | ||
506 | ) | ||
507 | |||
508 | with torch.autocast("cuda"), torch.inference_mode(): | ||
509 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | ||
510 | all_samples = [] | ||
511 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | ||
512 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
513 | |||
514 | data_enum = enumerate(data) | ||
515 | |||
516 | batches = [ | ||
517 | batch | ||
518 | for j, batch in data_enum | ||
519 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | ||
520 | ] | ||
521 | prompts = [ | ||
522 | prompt.format(identifier=self.instance_identifier) | ||
523 | for batch in batches | ||
524 | for prompt in batch["prompts"] | ||
525 | ] | ||
526 | nprompts = [ | ||
527 | prompt | ||
528 | for batch in batches | ||
529 | for prompt in batch["nprompts"] | ||
530 | ] | ||
531 | |||
532 | for i in range(self.sample_batches): | ||
533 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
534 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
535 | |||
536 | samples = pipeline( | ||
537 | prompt=prompt, | ||
538 | negative_prompt=nprompt, | ||
539 | height=self.sample_image_size, | ||
540 | width=self.sample_image_size, | ||
541 | image=latents[:len(prompt)] if latents is not None else None, | ||
542 | generator=generator if latents is not None else None, | ||
543 | guidance_scale=guidance_scale, | ||
544 | eta=eta, | ||
545 | num_inference_steps=num_inference_steps, | ||
546 | output_type='pil' | ||
547 | ).images | ||
548 | |||
549 | all_samples += samples | ||
550 | |||
551 | del samples | ||
552 | |||
553 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | ||
554 | image_grid.save(file_path, quality=85) | ||
555 | |||
556 | del all_samples | ||
557 | del image_grid | ||
558 | 466 | ||
559 | del unet | 467 | del unet |
560 | del text_encoder | 468 | del text_encoder |
561 | del pipeline | 469 | del pipeline |
562 | del generator | ||
563 | del stable_latents | ||
564 | 470 | ||
565 | if torch.cuda.is_available(): | 471 | if torch.cuda.is_available(): |
566 | torch.cuda.empty_cache() | 472 | torch.cuda.empty_cache() |
diff --git a/train_ti.py b/train_ti.py
index 9616db6..198cf37 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -7,7 +7,6 @@ import logging
7 | import json | 7 | import json |
8 | from pathlib import Path | 8 | from pathlib import Path |
9 | 9 | ||
10 | import numpy as np | ||
11 | import torch | 10 | import torch |
12 | import torch.nn.functional as F | 11 | import torch.nn.functional as F |
13 | import torch.utils.checkpoint | 12 | import torch.utils.checkpoint |
@@ -17,7 +16,6 @@ from accelerate.logging import get_logger
17 | from accelerate.utils import LoggerType, set_seed | 16 | from accelerate.utils import LoggerType, set_seed |
18 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel | 17 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 18 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
20 | from PIL import Image | ||
21 | from tqdm.auto import tqdm | 19 | from tqdm.auto import tqdm |
22 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
23 | from slugify import slugify | 21 | from slugify import slugify |
@@ -26,6 +24,7 @@ from common import load_text_embeddings, load_text_embedding
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import CSVDataModule, CSVDataItem |
28 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args | ||
29 | from models.clip.prompt import PromptProcessor | 28 | from models.clip.prompt import PromptProcessor |
30 | 29 | ||
31 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
@@ -138,7 +137,7 @@ def parse_args():
138 | parser.add_argument( | 137 | parser.add_argument( |
139 | "--tag_dropout", | 138 | "--tag_dropout", |
140 | type=float, | 139 | type=float, |
141 | default=0, | 140 | default=0.1, |
142 | help="Tag dropout probability.", | 141 | help="Tag dropout probability.", |
143 | ) | 142 | ) |
144 | parser.add_argument( | 143 | parser.add_argument( |
@@ -355,27 +354,7 @@ def parse_args():
355 | return args | 354 | return args |
356 | 355 | ||
357 | 356 | ||
358 | def freeze_params(params): | 357 | class Checkpointer(CheckpointerBase): |
359 | for param in params: | ||
360 | param.requires_grad = False | ||
361 | |||
362 | |||
363 | def save_args(basepath: Path, args, extra={}): | ||
364 | info = {"args": vars(args)} | ||
365 | info["args"].update(extra) | ||
366 | with open(basepath.joinpath("args.json"), "w") as f: | ||
367 | json.dump(info, f, indent=4) | ||
368 | |||
369 | |||
370 | def make_grid(images, rows, cols): | ||
371 | w, h = images[0].size | ||
372 | grid = Image.new('RGB', size=(cols*w, rows*h)) | ||
373 | for i, image in enumerate(images): | ||
374 | grid.paste(image, box=(i % cols*w, i//cols*h)) | ||
375 | return grid | ||
376 | |||
377 | |||
378 | class Checkpointer: | ||
379 | def __init__( | 358 | def __init__( |
380 | self, | 359 | self, |
381 | datamodule, | 360 | datamodule, |
@@ -394,21 +373,24 @@ class Checkpointer:
394 | sample_batch_size, | 373 | sample_batch_size, |
395 | seed | 374 | seed |
396 | ): | 375 | ): |
397 | self.datamodule = datamodule | 376 | super().__init__( |
377 | datamodule=datamodule, | ||
378 | output_dir=output_dir, | ||
379 | instance_identifier=instance_identifier, | ||
380 | placeholder_token=placeholder_token, | ||
381 | placeholder_token_id=placeholder_token_id, | ||
382 | sample_image_size=sample_image_size, | ||
383 | seed=seed or torch.random.seed(), | ||
384 | sample_batches=sample_batches, | ||
385 | sample_batch_size=sample_batch_size | ||
386 | ) | ||
387 | |||
398 | self.accelerator = accelerator | 388 | self.accelerator = accelerator |
399 | self.vae = vae | 389 | self.vae = vae |
400 | self.unet = unet | 390 | self.unet = unet |
401 | self.tokenizer = tokenizer | 391 | self.tokenizer = tokenizer |
402 | self.text_encoder = text_encoder | 392 | self.text_encoder = text_encoder |
403 | self.scheduler = scheduler | 393 | self.scheduler = scheduler |
404 | self.instance_identifier = instance_identifier | ||
405 | self.placeholder_token = placeholder_token | ||
406 | self.placeholder_token_id = placeholder_token_id | ||
407 | self.output_dir = output_dir | ||
408 | self.sample_image_size = sample_image_size | ||
409 | self.seed = seed or torch.random.seed() | ||
410 | self.sample_batches = sample_batches | ||
411 | self.sample_batch_size = sample_batch_size | ||
412 | 394 | ||
413 | @torch.no_grad() | 395 | @torch.no_grad() |
414 | def checkpoint(self, step, postfix): | 396 | def checkpoint(self, step, postfix): |
@@ -431,9 +413,7 @@ class Checkpointer:
431 | del learned_embeds | 413 | del learned_embeds |
432 | 414 | ||
433 | @torch.no_grad() | 415 | @torch.no_grad() |
434 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): | 416 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
435 | samples_path = Path(self.output_dir).joinpath("samples") | ||
436 | |||
437 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 417 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
438 | 418 | ||
439 | # Save a sample image | 419 | # Save a sample image |
@@ -446,71 +426,10 @@ class Checkpointer:
446 | ).to(self.accelerator.device) | 426 | ).to(self.accelerator.device) |
447 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 427 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
448 | 428 | ||
449 | train_data = self.datamodule.train_dataloader() | 429 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) |
450 | val_data = self.datamodule.val_dataloader() | ||
451 | |||
452 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | ||
453 | stable_latents = torch.randn( | ||
454 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | ||
455 | device=pipeline.device, | ||
456 | generator=generator, | ||
457 | ) | ||
458 | |||
459 | with torch.autocast("cuda"), torch.inference_mode(): | ||
460 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | ||
461 | all_samples = [] | ||
462 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | ||
463 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
464 | |||
465 | data_enum = enumerate(data) | ||
466 | |||
467 | batches = [ | ||
468 | batch | ||
469 | for j, batch in data_enum | ||
470 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | ||
471 | ] | ||
472 | prompts = [ | ||
473 | prompt.format(identifier=self.instance_identifier) | ||
474 | for batch in batches | ||
475 | for prompt in batch["prompts"] | ||
476 | ] | ||
477 | nprompts = [ | ||
478 | prompt | ||
479 | for batch in batches | ||
480 | for prompt in batch["nprompts"] | ||
481 | ] | ||
482 | |||
483 | for i in range(self.sample_batches): | ||
484 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
485 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
486 | |||
487 | samples = pipeline( | ||
488 | prompt=prompt, | ||
489 | negative_prompt=nprompt, | ||
490 | height=self.sample_image_size, | ||
491 | width=self.sample_image_size, | ||
492 | image=latents[:len(prompt)] if latents is not None else None, | ||
493 | generator=generator if latents is not None else None, | ||
494 | guidance_scale=guidance_scale, | ||
495 | eta=eta, | ||
496 | num_inference_steps=num_inference_steps, | ||
497 | output_type='pil' | ||
498 | ).images | ||
499 | |||
500 | all_samples += samples | ||
501 | |||
502 | del samples | ||
503 | |||
504 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | ||
505 | image_grid.save(file_path, quality=85) | ||
506 | |||
507 | del all_samples | ||
508 | del image_grid | ||
509 | 430 | ||
510 | del text_encoder | 431 | del text_encoder |
511 | del pipeline | 432 | del pipeline |
512 | del generator | ||
513 | del stable_latents | ||
514 | 433 | ||
515 | if torch.cuda.is_available(): | 434 | if torch.cuda.is_available(): |
516 | torch.cuda.empty_cache() | 435 | torch.cuda.empty_cache() |
@@ -814,7 +733,14 @@ def main():
814 | # Only show the progress bar once on each machine. | 733 | # Only show the progress bar once on each machine. |
815 | 734 | ||
816 | global_step = 0 | 735 | global_step = 0 |
817 | min_val_loss = np.inf | 736 | |
737 | avg_loss = AverageMeter() | ||
738 | avg_acc = AverageMeter() | ||
739 | |||
740 | avg_loss_val = AverageMeter() | ||
741 | avg_acc_val = AverageMeter() | ||
742 | |||
743 | max_acc_val = 0.0 | ||
818 | 744 | ||
819 | checkpointer = Checkpointer( | 745 | checkpointer = Checkpointer( |
820 | datamodule=datamodule, | 746 | datamodule=datamodule, |
@@ -835,9 +761,7 @@ def main():
835 | ) | 761 | ) |
836 | 762 | ||
837 | if accelerator.is_main_process: | 763 | if accelerator.is_main_process: |
838 | checkpointer.save_samples( | 764 | checkpointer.save_samples(global_step_offset, args.sample_steps) |
839 | 0, | ||
840 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
841 | 765 | ||
842 | local_progress_bar = tqdm( | 766 | local_progress_bar = tqdm( |
843 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 767 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
@@ -910,6 +834,8 @@ def main():
910 | else: | 834 | else: |
911 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 835 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
912 | 836 | ||
837 | acc = (model_pred == latents).float().mean() | ||
838 | |||
913 | accelerator.backward(loss) | 839 | accelerator.backward(loss) |
914 | 840 | ||
915 | optimizer.step() | 841 | optimizer.step() |
@@ -922,8 +848,8 @@ def main():
922 | text_encoder.get_input_embeddings( | 848 | text_encoder.get_input_embeddings( |
923 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] | 849 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] |
924 | 850 | ||
925 | loss = loss.detach().item() | 851 | avg_loss.update(loss.detach_(), bsz) |
926 | train_loss += loss | 852 | avg_acc.update(acc.detach_(), bsz) |
927 | 853 | ||
928 | # Checks if the accelerator has performed an optimization step behind the scenes | 854 | # Checks if the accelerator has performed an optimization step behind the scenes |
929 | if accelerator.sync_gradients: | 855 | if accelerator.sync_gradients: |
@@ -932,7 +858,13 @@ def main():
932 | 858 | ||
933 | global_step += 1 | 859 | global_step += 1 |
934 | 860 | ||
935 | logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 861 | logs = { |
862 | "train/loss": avg_loss.avg.item(), | ||
863 | "train/acc": avg_acc.avg.item(), | ||
864 | "train/cur_loss": loss.item(), | ||
865 | "train/cur_acc": acc.item(), | ||
866 | "lr": lr_scheduler.get_last_lr()[0], | ||
867 | } | ||
936 | 868 | ||
937 | accelerator.log(logs, step=global_step) | 869 | accelerator.log(logs, step=global_step) |
938 | 870 | ||
@@ -941,12 +873,9 @@ def main():
941 | if global_step >= args.max_train_steps: | 873 | if global_step >= args.max_train_steps: |
942 | break | 874 | break |
943 | 875 | ||
944 | train_loss /= len(train_dataloader) | ||
945 | |||
946 | accelerator.wait_for_everyone() | 876 | accelerator.wait_for_everyone() |
947 | 877 | ||
948 | text_encoder.eval() | 878 | text_encoder.eval() |
949 | val_loss = 0.0 | ||
950 | 879 | ||
951 | with torch.inference_mode(): | 880 | with torch.inference_mode(): |
952 | for step, batch in enumerate(val_dataloader): | 881 | for step, batch in enumerate(val_dataloader): |
@@ -976,29 +905,37 @@ def main():
976 | 905 | ||
977 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 906 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
978 | 907 | ||
979 | loss = loss.detach().item() | 908 | acc = (model_pred == latents).float().mean() |
980 | val_loss += loss | 909 | |
910 | avg_loss_val.update(loss.detach_(), bsz) | ||
911 | avg_acc_val.update(acc.detach_(), bsz) | ||
981 | 912 | ||
982 | if accelerator.sync_gradients: | 913 | if accelerator.sync_gradients: |
983 | local_progress_bar.update(1) | 914 | local_progress_bar.update(1) |
984 | global_progress_bar.update(1) | 915 | global_progress_bar.update(1) |
985 | 916 | ||
986 | logs = {"val/loss": loss} | 917 | logs = { |
918 | "val/loss": avg_loss_val.avg.item(), | ||
919 | "val/acc": avg_acc_val.avg.item(), | ||
920 | "val/cur_loss": loss.item(), | ||
921 | "val/cur_acc": acc.item(), | ||
922 | } | ||
987 | local_progress_bar.set_postfix(**logs) | 923 | local_progress_bar.set_postfix(**logs) |
988 | 924 | ||
989 | val_loss /= len(val_dataloader) | 925 | accelerator.log({ |
990 | 926 | "val/loss": avg_loss_val.avg.item(), | |
991 | accelerator.log({"val/loss": val_loss}, step=global_step) | 927 | "val/acc": avg_acc_val.avg.item(), |
928 | }, step=global_step) | ||
992 | 929 | ||
993 | local_progress_bar.clear() | 930 | local_progress_bar.clear() |
994 | global_progress_bar.clear() | 931 | global_progress_bar.clear() |
995 | 932 | ||
996 | if accelerator.is_main_process: | 933 | if accelerator.is_main_process: |
997 | if min_val_loss > val_loss: | 934 | if avg_acc_val.avg.item() > max_acc_val: |
998 | accelerator.print( | 935 | accelerator.print( |
999 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 936 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
1000 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | 937 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") |
1001 | min_val_loss = val_loss | 938 | max_acc_val = avg_acc_val.avg.item() |
1002 | 939 | ||
1003 | if (epoch + 1) % args.checkpoint_frequency == 0: | 940 | if (epoch + 1) % args.checkpoint_frequency == 0: |
1004 | checkpointer.checkpoint(global_step + global_step_offset, "training") | 941 | checkpointer.checkpoint(global_step + global_step_offset, "training") |
@@ -1007,9 +944,7 @@ def main():
1007 | }) | 944 | }) |
1008 | 945 | ||
1009 | if (epoch + 1) % args.sample_frequency == 0: | 946 | if (epoch + 1) % args.sample_frequency == 0: |
1010 | checkpointer.save_samples( | 947 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) |
1011 | global_step + global_step_offset, | ||
1012 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
1013 | 948 | ||
1014 | # Create the pipeline using using the trained modules and save it. | 949 | # Create the pipeline using using the trained modules and save it. |
1015 | if accelerator.is_main_process: | 950 | if accelerator.is_main_process: |
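For orientation, the net effect of the train_ti.py changes above: per-step losses are now accumulated in batch-weighted AverageMeter instances, a rough accuracy signal is logged alongside them, and the "milestone" checkpoint is kept when validation accuracy reaches a new maximum instead of when validation loss reaches a new minimum. A toy, self-contained sketch of that selection rule follows (fabricated numbers, checkpoint call stubbed out; not the script's verbatim loop):

from training.util import AverageMeter

def end_of_epoch(per_batch_acc, max_acc_val):
    # per_batch_acc: list of (accuracy, batch_size) pairs from one validation pass
    avg_acc_val = AverageMeter()
    for acc, bsz in per_batch_acc:
        avg_acc_val.update(acc, bsz)
    if avg_acc_val.avg > max_acc_val:
        print(f"Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
        # the real script calls checkpointer.checkpoint(global_step + global_step_offset, "milestone") here
        max_acc_val = avg_acc_val.avg
    return max_acc_val

max_acc_val = 0.0
max_acc_val = end_of_epoch([(0.12, 4), (0.18, 4)], max_acc_val)  # improvement -> milestone saved
max_acc_val = end_of_epoch([(0.10, 4), (0.11, 4)], max_acc_val)  # no improvement -> unchanged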
diff --git a/training/optimization.py b/training/optimization.py
index 0e603fa..c501ed9 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ from diffusers.utils import logging
6 | logger = logging.get_logger(__name__) | 6 | logger = logging.get_logger(__name__) |
7 | 7 | ||
8 | 8 | ||
9 | def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.4, last_epoch=-1): | 9 | def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.001, mid_point=0.4, last_epoch=-1): |
10 | """ | 10 | """ |
11 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after | 11 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after |
12 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. | 12 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. |
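The only change to training/optimization.py is the default floor of the one-cycle schedule: min_lr drops from 0.05 to 0.001, i.e. the learning rate now decays to 0.1% of its peak rather than 5% (assuming, as the old default suggests, that min_lr is a fraction of the peak LR). Below is a minimal sketch of how such a floor typically enters a LambdaLR-style one-cycle schedule; it is an illustration, not the repository's exact implementation:

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def one_cycle_sketch(optimizer, num_training_steps, min_lr=0.001, mid_point=0.4, last_epoch=-1):
    ramp_end = max(1, int(num_training_steps * mid_point))  # warmup portion of the cycle

    def lr_lambda(step):
        if step < ramp_end:
            return step / ramp_end  # linear warmup 0 -> 1
        # cosine anneal from 1 down to min_lr over the remaining steps
        progress = (step - ramp_end) / max(1, num_training_steps - ramp_end)
        return min_lr + 0.5 * (1.0 - min_lr) * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

# usage sketch
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = one_cycle_sketch(optimizer, num_training_steps=1000)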
diff --git a/training/util.py b/training/util.py
new file mode 100644
index 0000000..e8d22ae
--- /dev/null
+++ b/training/util.py
@@ -0,0 +1,131 @@
1 | from pathlib import Path | ||
2 | import json | ||
3 | |||
4 | import torch | ||
5 | from PIL import Image | ||
6 | |||
7 | |||
8 | def freeze_params(params): | ||
9 | for param in params: | ||
10 | param.requires_grad = False | ||
11 | |||
12 | |||
13 | def save_args(basepath: Path, args, extra={}): | ||
14 | info = {"args": vars(args)} | ||
15 | info["args"].update(extra) | ||
16 | with open(basepath.joinpath("args.json"), "w") as f: | ||
17 | json.dump(info, f, indent=4) | ||
18 | |||
19 | |||
20 | def make_grid(images, rows, cols): | ||
21 | w, h = images[0].size | ||
22 | grid = Image.new('RGB', size=(cols*w, rows*h)) | ||
23 | for i, image in enumerate(images): | ||
24 | grid.paste(image, box=(i % cols*w, i//cols*h)) | ||
25 | return grid | ||
26 | |||
27 | |||
28 | class AverageMeter: | ||
29 | def __init__(self, name=None): | ||
30 | self.name = name | ||
31 | self.reset() | ||
32 | |||
33 | def reset(self): | ||
34 | self.sum = self.count = self.avg = 0 | ||
35 | |||
36 | def update(self, val, n=1): | ||
37 | self.sum += val * n | ||
38 | self.count += n | ||
39 | self.avg = self.sum / self.count | ||
40 | |||
41 | |||
42 | class CheckpointerBase: | ||
43 | def __init__( | ||
44 | self, | ||
45 | datamodule, | ||
46 | output_dir: Path, | ||
47 | instance_identifier, | ||
48 | placeholder_token, | ||
49 | placeholder_token_id, | ||
50 | sample_image_size, | ||
51 | sample_batches, | ||
52 | sample_batch_size, | ||
53 | seed | ||
54 | ): | ||
55 | self.datamodule = datamodule | ||
56 | self.output_dir = output_dir | ||
57 | self.instance_identifier = instance_identifier | ||
58 | self.placeholder_token = placeholder_token | ||
59 | self.placeholder_token_id = placeholder_token_id | ||
60 | self.sample_image_size = sample_image_size | ||
61 | self.seed = seed or torch.random.seed() | ||
62 | self.sample_batches = sample_batches | ||
63 | self.sample_batch_size = sample_batch_size | ||
64 | |||
65 | @torch.no_grad() | ||
66 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | ||
67 | samples_path = Path(self.output_dir).joinpath("samples") | ||
68 | |||
69 | train_data = self.datamodule.train_dataloader() | ||
70 | val_data = self.datamodule.val_dataloader() | ||
71 | |||
72 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | ||
73 | stable_latents = torch.randn( | ||
74 | (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8), | ||
75 | device=pipeline.device, | ||
76 | generator=generator, | ||
77 | ) | ||
78 | |||
79 | with torch.autocast("cuda"), torch.inference_mode(): | ||
80 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | ||
81 | all_samples = [] | ||
82 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | ||
83 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
84 | |||
85 | data_enum = enumerate(data) | ||
86 | |||
87 | batches = [ | ||
88 | batch | ||
89 | for j, batch in data_enum | ||
90 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | ||
91 | ] | ||
92 | prompts = [ | ||
93 | prompt.format(identifier=self.instance_identifier) | ||
94 | for batch in batches | ||
95 | for prompt in batch["prompts"] | ||
96 | ] | ||
97 | nprompts = [ | ||
98 | prompt | ||
99 | for batch in batches | ||
100 | for prompt in batch["nprompts"] | ||
101 | ] | ||
102 | |||
103 | for i in range(self.sample_batches): | ||
104 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
105 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
106 | |||
107 | samples = pipeline( | ||
108 | prompt=prompt, | ||
109 | negative_prompt=nprompt, | ||
110 | height=self.sample_image_size, | ||
111 | width=self.sample_image_size, | ||
112 | image=latents[:len(prompt)] if latents is not None else None, | ||
113 | generator=generator if latents is not None else None, | ||
114 | guidance_scale=guidance_scale, | ||
115 | eta=eta, | ||
116 | num_inference_steps=num_inference_steps, | ||
117 | output_type='pil' | ||
118 | ).images | ||
119 | |||
120 | all_samples += samples | ||
121 | |||
122 | del samples | ||
123 | |||
124 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | ||
125 | image_grid.save(file_path, quality=85) | ||
126 | |||
127 | del all_samples | ||
128 | del image_grid | ||
129 | |||
130 | del generator | ||
131 | del stable_latents | ||
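A quick usage sketch for the helpers collected into the new training/util.py (illustrative values; assumes the repository root is on PYTHONPATH so training.util is importable):

import torch
from PIL import Image
from training.util import AverageMeter, freeze_params, make_grid

# AverageMeter keeps a count-weighted running average, e.g. of the loss.
meter = AverageMeter()
for step in range(4):
    loss = 0.5 / (step + 1)
    meter.update(loss, n=8)          # n = number of samples in the batch
print(f"running avg loss: {meter.avg:.4f}")

# freeze_params disables gradients, e.g. for the VAE during training.
vae_like = torch.nn.Linear(4, 4)
freeze_params(vae_like.parameters())

# make_grid pastes PIL images into a rows x cols contact sheet.
imgs = [Image.new("RGB", (64, 64), c) for c in ("red", "green", "blue", "white")]
make_grid(imgs, rows=2, cols=2).save("/tmp/grid.jpg", quality=85)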