author     Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
commit     adc52fb8821a496bc8d78235bf10466b39df03e0 (patch)
tree       8a6337a6ac10cbe76c55514ab559c647e69fb1aa /train_dreambooth.py
parent     Fixed accuracy calc, other improvements (diff)
Updates
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  228
1 file changed, 124 insertions(+), 104 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1ebcfe3..b07de31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -3,7 +3,6 @@ import itertools
 import math
 import datetime
 import logging
-import json
 from pathlib import Path

 import torch
@@ -15,18 +14,21 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify

-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.ti import patch_trainable_embeddings
+from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer

 logger = get_logger(__name__)

@@ -106,6 +108,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
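Note: the help text above only names the --vector_shuffle modes. As a rough illustration (a hypothetical sketch, not the MultiCLIPTokenizer implementation in models/clip/tokenizer.py), per-mode shuffling of a multi-vector placeholder token's IDs could look like:

    import random

    def shuffle_ids(ids: list, mode: str) -> list:
        # Hypothetical per-mode shuffle; "auto" presumably picks a mode heuristically.
        if mode == "off" or len(ids) < 2:
            return ids
        if mode == "all":
            return random.sample(ids, len(ids))
        if mode == "trailing":   # keep the first vector anchored
            return ids[:1] + random.sample(ids[1:], len(ids) - 1)
        if mode == "leading":    # keep the last vector anchored
            return random.sample(ids[:-1], len(ids) - 1) + ids[-1:]
        if mode == "between":    # keep both ends anchored
            middle = ids[1:-1]
            return ids[:1] + random.sample(middle, len(middle)) + ids[-1:]
        raise ValueError(f"unknown vector_shuffle mode: {mode}")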
@@ -193,13 +201,12 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
-        type=float,
-        default=2e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate_text",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -546,9 +553,9 @@ def main():

     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -558,6 +565,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')

+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -576,46 +585,42 @@ def main():
         device=accelerator.device
     )

-    # Freeze text_encoder and vae
-    vae.requires_grad_(False)
+    embeddings = patch_managed_embeddings(text_encoder)

     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")

     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = torch.stack([
-            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        initializer_token_ids = [
+            tokenizer.encode(token, add_special_tokens=False)
             for token in args.initializer_token
-        ])
-
-        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        print(f"Added {num_added_tokens} new tokens.")
+        ]

-        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        embeddings.resize(len(tokenizer))

-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        text_encoder.resize_token_embeddings(len(tokenizer))
+        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+            embeddings.add_embed(new_token.ids, init_ids)

-        token_embeds = text_encoder.get_input_embeddings().weight.data
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+        print(f"Added {len(new_tokens)} new tokens.")
     else:
         placeholder_token_id = []

+    vae.requires_grad_(False)
+
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
+
+        unpatch_managed_embeddings(text_encoder)
     else:
         print(f"Training added text embeddings")

-        patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
         text_encoder.text_model.encoder.requires_grad_(False)
         text_encoder.text_model.final_layer_norm.requires_grad_(False)
         text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
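Note: add_multi_tokens and add_embed live in models/clip/tokenizer.py and models/clip/embeddings.py, which this diff does not show. Conceptually, the initialization step copies the initializer token's embedding rows into the freshly allocated rows; a minimal sketch of that idea (assumed behavior, not the actual code):

    import torch

    def add_embed(weight: torch.Tensor, new_ids: list, init_ids: list) -> None:
        # Copy initializer embeddings into the new token rows, cycling over the
        # initializer IDs when the placeholder has more vectors than initializers.
        with torch.no_grad():
            for i, new_id in enumerate(new_ids):
                weight[new_id] = weight[init_ids[i % len(init_ids)]]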
@@ -624,15 +629,14 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)

     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-        args.learning_rate_text = (
-            args.learning_rate_text * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )

+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -647,20 +651,19 @@ def main():
     if args.train_text_encoder:
         text_encoder_params_to_optimize = text_encoder.parameters()
     else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters()
+        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()

     # Initialize the optimizer
     optimizer = optimizer_class(
         [
             {
                 'params': unet.parameters(),
-                'lr': args.learning_rate_unet,
             },
             {
                 'params': text_encoder_params_to_optimize,
-                'lr': args.learning_rate_text,
             }
         ],
+        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
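Note: dropping the per-group 'lr' entries works because PyTorch fills any key missing from a param group with the optimizer-level default, so lr=args.learning_rate now applies to both the UNet group and the text-encoder group. A quick demonstration:

    import torch

    model_a = torch.nn.Linear(4, 4)
    model_b = torch.nn.Linear(4, 4)

    optimizer = torch.optim.AdamW(
        [
            {'params': model_a.parameters()},  # no per-group 'lr'
            {'params': model_b.parameters()},
        ],
        lr=2e-6,  # default propagated into every group
    )
    assert all(group['lr'] == 2e-6 for group in optimizer.param_groups)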
@@ -824,6 +827,58 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs

+    def loop(batch):
+        # Convert images to latent space
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+        latents = latents * 0.18215
+
+        # Sample noise that we'll add to the latents
+        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                  (bsz,), device=latents.device)
+        timesteps = timesteps.long()
+
+        # Add noise to the latents according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+        noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+        # Get the text embedding for conditioning
+        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+        # Predict the noise residual
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+        # Get the target for loss depending on the prediction type
+        if noise_scheduler.config.prediction_type == "epsilon":
+            target = noise
+        elif noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+        if args.num_class_images != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
+
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+            # Add the prior loss to the instance loss.
+            loss = loss + args.prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        acc = (model_pred == target).float().mean()
+
+        return loss, acc, bsz
+
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
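Note: the prior-preservation branch in loop() assumes each batch concatenates instance samples and class samples along dim 0, so torch.chunk splits them back apart. A self-contained sketch of that loss with dummy tensors (prior_loss_weight stands in for args.prior_loss_weight):

    import torch
    import torch.nn.functional as F

    prior_loss_weight = 1.0

    # Assumed batch layout: [instance samples | class samples] along dim 0.
    model_pred = torch.randn(4, 4, 64, 64)
    target = torch.randn(4, 4, 64, 64)

    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    loss = loss + prior_loss_weight * prior_loss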
@@ -836,6 +891,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)

+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

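Note: LRFinder comes from training/lr.py, which this diff does not show; judging by the call site, it sweeps the learning rate from min_lr up toward the args.learning_rate set earlier (1e2) and leaves a matplotlib figure behind to save. The general shape of such a range test, sketched generically (not the repo's implementation):

    def lr_range_test(optimizer, loss_fn, batches, min_lr=1e-4, max_lr=1e2, num_steps=100):
        # Raise the LR exponentially each step and record (lr, loss);
        # pick an LR just before the loss curve starts to diverge.
        gamma = (max_lr / min_lr) ** (1 / num_steps)
        lr = min_lr
        history = []
        for _, batch in zip(range(num_steps), batches):
            for group in optimizer.param_groups:
                group['lr'] = lr
            optimizer.zero_grad()
            loss = loss_fn(batch)
            loss.backward()
            optimizer.step()
            history.append((lr, loss.item()))
            lr *= gamma
        return history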
@@ -893,58 +957,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")

-    def loop(batch):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -993,8 +1005,7 @@ def main():
                     "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
-                    "lr/unet": lr_scheduler.get_last_lr()[0],
-                    "lr/text": lr_scheduler.get_last_lr()[1]
+                    "lr": lr_scheduler.get_last_lr()[0]
                 }
                 if args.use_ema:
                     logs["ema_decay"] = 1 - ema_unet.decay
@@ -1011,12 +1022,21 @@ def main():
             unet.eval()
             text_encoder.eval()

+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loop(batch)

-                    avg_loss_val.update(loss.detach_(), bsz)
-                    avg_acc_val.update(acc.detach_(), bsz)
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)

                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
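Note: the new cur_loss_val/cur_acc_val meters track per-validation-pass averages separately from the running avg_* meters. AverageMeter itself comes from training/util.py; a typical weighted running-average meter looks like this (a sketch of the usual pattern, not necessarily the repo's exact class):

    class AverageMeter:
        # Running average weighted by batch size; works with floats or 0-dim tensors.
        def __init__(self):
            self.sum = 0
            self.count = 0
            self.avg = 0

        def update(self, val, n=1):
            self.sum = self.sum + val * n
            self.count += n
            self.avg = self.sum / self.count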
@@ -1029,20 +1049,20 @@ def main():
                 }
                 local_progress_bar.set_postfix(**logs)

-            accelerator.log({
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-            }, step=global_step)
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)

             local_progress_bar.clear()
             global_progress_bar.clear()

-            if avg_acc_val.avg.item() > max_acc_val:
-                accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                max_acc_val = avg_acc_val.avg.item()
-
             if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    max_acc_val = avg_acc_val.avg.item()
+
                 if (epoch + 1) % args.sample_frequency == 0:
                     checkpointer.save_samples(global_step, args.sample_steps)
