diff options
| -rw-r--r-- | dreambooth.py | 49 | ||||
| -rw-r--r-- | infer.py | 2 | ||||
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 34 | ||||
| -rw-r--r-- | textual_inversion.py | 119 |
4 files changed, 97 insertions, 107 deletions
diff --git a/dreambooth.py b/dreambooth.py index 3110c6d..9a6f70a 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -13,7 +13,7 @@ import torch.utils.checkpoint | |||
| 13 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
| 14 | from accelerate.logging import get_logger | 14 | from accelerate.logging import get_logger |
| 15 | from accelerate.utils import LoggerType, set_seed | 15 | from accelerate.utils import LoggerType, set_seed |
| 16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel | 16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
| 17 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 17 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
| 18 | from diffusers.training_utils import EMAModel | 18 | from diffusers.training_utils import EMAModel |
| 19 | from PIL import Image | 19 | from PIL import Image |
| @@ -204,7 +204,7 @@ def parse_args(): | |||
| 204 | parser.add_argument( | 204 | parser.add_argument( |
| 205 | "--lr_warmup_epochs", | 205 | "--lr_warmup_epochs", |
| 206 | type=int, | 206 | type=int, |
| 207 | default=20, | 207 | default=10, |
| 208 | help="Number of steps for the warmup in the lr scheduler." | 208 | help="Number of steps for the warmup in the lr scheduler." |
| 209 | ) | 209 | ) |
| 210 | parser.add_argument( | 210 | parser.add_argument( |
| @@ -558,11 +558,11 @@ class Checkpointer: | |||
| 558 | def main(): | 558 | def main(): |
| 559 | args = parse_args() | 559 | args = parse_args() |
| 560 | 560 | ||
| 561 | # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: | 561 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: |
| 562 | # raise ValueError( | 562 | raise ValueError( |
| 563 | # "Gradient accumulation is not supported when training the text encoder in distributed training. " | 563 | "Gradient accumulation is not supported when training the text encoder in distributed training. " |
| 564 | # "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." | 564 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." |
| 565 | # ) | 565 | ) |
| 566 | 566 | ||
| 567 | instance_identifier = args.instance_identifier | 567 | instance_identifier = args.instance_identifier |
| 568 | 568 | ||
| @@ -645,9 +645,9 @@ def main(): | |||
| 645 | 645 | ||
| 646 | print(f"Token ID mappings:") | 646 | print(f"Token ID mappings:") |
| 647 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | 647 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): |
| 648 | print(f"- {token_id} {token}") | ||
| 649 | |||
| 650 | embedding_file = embeddings_dir.joinpath(f"{token}.bin") | 648 | embedding_file = embeddings_dir.joinpath(f"{token}.bin") |
| 649 | embedding_source = "init" | ||
| 650 | |||
| 651 | if embedding_file.exists() and embedding_file.is_file(): | 651 | if embedding_file.exists() and embedding_file.is_file(): |
| 652 | embedding_data = torch.load(embedding_file, map_location="cpu") | 652 | embedding_data = torch.load(embedding_file, map_location="cpu") |
| 653 | 653 | ||
| @@ -656,8 +656,11 @@ def main(): | |||
| 656 | emb = emb.unsqueeze(0) | 656 | emb = emb.unsqueeze(0) |
| 657 | 657 | ||
| 658 | token_embeds[token_id] = emb | 658 | token_embeds[token_id] = emb |
| 659 | embedding_source = "file" | ||
| 659 | 660 | ||
| 660 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | 661 | print(f"- {token_id} {token} ({embedding_source})") |
| 662 | |||
| 663 | original_token_embeds = token_embeds.clone().to(accelerator.device) | ||
| 661 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 664 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
| 662 | 665 | ||
| 663 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 666 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): |
| @@ -946,7 +949,7 @@ def main(): | |||
| 946 | sample_checkpoint = False | 949 | sample_checkpoint = False |
| 947 | 950 | ||
| 948 | for step, batch in enumerate(train_dataloader): | 951 | for step, batch in enumerate(train_dataloader): |
| 949 | with accelerator.accumulate(itertools.chain(unet, text_encoder)): | 952 | with accelerator.accumulate(unet): |
| 950 | # Convert images to latent space | 953 | # Convert images to latent space |
| 951 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 954 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| 952 | latents = latents * 0.18215 | 955 | latents = latents * 0.18215 |
| @@ -997,16 +1000,6 @@ def main(): | |||
| 997 | 1000 | ||
| 998 | accelerator.backward(loss) | 1001 | accelerator.backward(loss) |
| 999 | 1002 | ||
| 1000 | if not args.train_text_encoder: | ||
| 1001 | # Keep the token embeddings fixed except the newly added | ||
| 1002 | # embeddings for the concept, as we only want to optimize the concept embeddings | ||
| 1003 | if accelerator.num_processes > 1: | ||
| 1004 | token_embeds = text_encoder.module.get_input_embeddings().weight | ||
| 1005 | else: | ||
| 1006 | token_embeds = text_encoder.get_input_embeddings().weight | ||
| 1007 | |||
| 1008 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] | ||
| 1009 | |||
| 1010 | if accelerator.sync_gradients: | 1003 | if accelerator.sync_gradients: |
| 1011 | params_to_clip = ( | 1004 | params_to_clip = ( |
| 1012 | itertools.chain(unet.parameters(), text_encoder.parameters()) | 1005 | itertools.chain(unet.parameters(), text_encoder.parameters()) |
| @@ -1022,6 +1015,12 @@ def main(): | |||
| 1022 | ema_unet.step(unet) | 1015 | ema_unet.step(unet) |
| 1023 | optimizer.zero_grad(set_to_none=True) | 1016 | optimizer.zero_grad(set_to_none=True) |
| 1024 | 1017 | ||
| 1018 | if not args.train_text_encoder: | ||
| 1019 | # Let's make sure we don't update any embedding weights besides the newly added token | ||
| 1020 | with torch.no_grad(): | ||
| 1021 | text_encoder.get_input_embeddings( | ||
| 1022 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] | ||
| 1023 | |||
| 1025 | avg_loss.update(loss.detach_(), bsz) | 1024 | avg_loss.update(loss.detach_(), bsz) |
| 1026 | avg_acc.update(acc.detach_(), bsz) | 1025 | avg_acc.update(acc.detach_(), bsz) |
| 1027 | 1026 | ||
| @@ -1032,9 +1031,6 @@ def main(): | |||
| 1032 | 1031 | ||
| 1033 | global_step += 1 | 1032 | global_step += 1 |
| 1034 | 1033 | ||
| 1035 | if global_step % args.sample_frequency == 0: | ||
| 1036 | sample_checkpoint = True | ||
| 1037 | |||
| 1038 | logs = { | 1034 | logs = { |
| 1039 | "train/loss": avg_loss.avg.item(), | 1035 | "train/loss": avg_loss.avg.item(), |
| 1040 | "train/acc": avg_acc.avg.item(), | 1036 | "train/acc": avg_acc.avg.item(), |
| @@ -1117,8 +1113,9 @@ def main(): | |||
| 1117 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 1113 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
| 1118 | max_acc_val = avg_acc_val.avg.item() | 1114 | max_acc_val = avg_acc_val.avg.item() |
| 1119 | 1115 | ||
| 1120 | if sample_checkpoint and accelerator.is_main_process: | 1116 | if accelerator.is_main_process: |
| 1121 | checkpointer.save_samples(global_step, args.sample_steps) | 1117 | if epoch % args.sample_frequency == 0: |
| 1118 | checkpointer.save_samples(global_step, args.sample_steps) | ||
| 1122 | 1119 | ||
| 1123 | # Create the pipeline using using the trained modules and save it. | 1120 | # Create the pipeline using using the trained modules and save it. |
| 1124 | if accelerator.is_main_process: | 1121 | if accelerator.is_main_process: |
diff --git a/infer.py b/infer.py --- a/infer.py +++ b/infer.py | |||
| @@ -31,7 +31,7 @@ torch.backends.cudnn.benchmark = True | |||
| 31 | 31 | ||
| 32 | 32 | ||
| 33 | default_args = { | 33 | default_args = { |
| 34 | "model": None, | 34 | "model": "stabilityai/stable-diffusion-2-1", |
| 35 | "precision": "fp32", | 35 | "precision": "fp32", |
| 36 | "ti_embeddings_dir": "embeddings_ti", | 36 | "ti_embeddings_dir": "embeddings_ti", |
| 37 | "output_dir": "output/inference", | 37 | "output_dir": "output/inference", |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 141b9a7..707b639 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -421,25 +421,29 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 421 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | 421 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
| 422 | 422 | ||
| 423 | # 7. Denoising loop | 423 | # 7. Denoising loop |
| 424 | for i, t in enumerate(self.progress_bar(timesteps)): | 424 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
| 425 | # expand the latents if we are doing classifier free guidance | 425 | with self.progress_bar(total=num_inference_steps) as progress_bar: |
| 426 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 426 | for i, t in enumerate(timesteps): |
| 427 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 427 | # expand the latents if we are doing classifier free guidance |
| 428 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | ||
| 429 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | ||
| 428 | 430 | ||
| 429 | # predict the noise residual | 431 | # predict the noise residual |
| 430 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 432 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
| 431 | 433 | ||
| 432 | # perform guidance | 434 | # perform guidance |
| 433 | if do_classifier_free_guidance: | 435 | if do_classifier_free_guidance: |
| 434 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | 436 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
| 435 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 437 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
| 436 | 438 | ||
| 437 | # compute the previous noisy sample x_t -> x_t-1 | 439 | # compute the previous noisy sample x_t -> x_t-1 |
| 438 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 440 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
| 439 | 441 | ||
| 440 | # call the callback, if provided | 442 | # call the callback, if provided |
| 441 | if callback is not None and i % callback_steps == 0: | 443 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
| 442 | callback(i, t, latents) | 444 | progress_bar.update() |
| 445 | if callback is not None and i % callback_steps == 0: | ||
| 446 | callback(i, t, latents) | ||
| 443 | 447 | ||
| 444 | # 8. Post-processing | 448 | # 8. Post-processing |
| 445 | image = self.decode_latents(latents) | 449 | image = self.decode_latents(latents) |
diff --git a/textual_inversion.py b/textual_inversion.py index a9c3326..11babd8 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -170,7 +170,7 @@ def parse_args(): | |||
| 170 | parser.add_argument( | 170 | parser.add_argument( |
| 171 | "--lr_scheduler", | 171 | "--lr_scheduler", |
| 172 | type=str, | 172 | type=str, |
| 173 | default="one_cycle", | 173 | default="constant_with_warmup", |
| 174 | help=( | 174 | help=( |
| 175 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 175 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
| 176 | ' "constant", "constant_with_warmup", "one_cycle"]' | 176 | ' "constant", "constant_with_warmup", "one_cycle"]' |
| @@ -231,14 +231,14 @@ def parse_args(): | |||
| 231 | parser.add_argument( | 231 | parser.add_argument( |
| 232 | "--checkpoint_frequency", | 232 | "--checkpoint_frequency", |
| 233 | type=int, | 233 | type=int, |
| 234 | default=500, | 234 | default=5, |
| 235 | help="How often to save a checkpoint and sample image", | 235 | help="How often to save a checkpoint and sample image (in epochs)", |
| 236 | ) | 236 | ) |
| 237 | parser.add_argument( | 237 | parser.add_argument( |
| 238 | "--sample_frequency", | 238 | "--sample_frequency", |
| 239 | type=int, | 239 | type=int, |
| 240 | default=100, | 240 | default=1, |
| 241 | help="How often to save a checkpoint and sample image", | 241 | help="How often to save a checkpoint and sample image (in epochs)", |
| 242 | ) | 242 | ) |
| 243 | parser.add_argument( | 243 | parser.add_argument( |
| 244 | "--sample_image_size", | 244 | "--sample_image_size", |
| @@ -294,10 +294,9 @@ def parse_args(): | |||
| 294 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" | 294 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" |
| 295 | ) | 295 | ) |
| 296 | parser.add_argument( | 296 | parser.add_argument( |
| 297 | "--resume_checkpoint", | 297 | "--global_step", |
| 298 | type=str, | 298 | type=int, |
| 299 | default=None, | 299 | default=0, |
| 300 | help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)." | ||
| 301 | ) | 300 | ) |
| 302 | parser.add_argument( | 301 | parser.add_argument( |
| 303 | "--config", | 302 | "--config", |
| @@ -512,19 +511,10 @@ def main(): | |||
| 512 | if len(args.placeholder_token) != 0: | 511 | if len(args.placeholder_token) != 0: |
| 513 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) | 512 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) |
| 514 | 513 | ||
| 515 | global_step_offset = 0 | 514 | global_step_offset = args.global_step |
| 516 | if args.resume_from is not None: | 515 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 517 | basepath = Path(args.resume_from) | 516 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) |
| 518 | print("Resuming state from %s" % args.resume_from) | 517 | basepath.mkdir(parents=True, exist_ok=True) |
| 519 | with open(basepath.joinpath("resume.json"), 'r') as f: | ||
| 520 | state = json.load(f) | ||
| 521 | global_step_offset = state["args"].get("global_step", 0) | ||
| 522 | |||
| 523 | print("We've trained %d steps so far" % global_step_offset) | ||
| 524 | else: | ||
| 525 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
| 526 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | ||
| 527 | basepath.mkdir(parents=True, exist_ok=True) | ||
| 528 | 518 | ||
| 529 | accelerator = Accelerator( | 519 | accelerator = Accelerator( |
| 530 | log_with=LoggerType.TENSORBOARD, | 520 | log_with=LoggerType.TENSORBOARD, |
| @@ -557,6 +547,7 @@ def main(): | |||
| 557 | set_use_memory_efficient_attention_xformers(vae, True) | 547 | set_use_memory_efficient_attention_xformers(vae, True) |
| 558 | 548 | ||
| 559 | if args.gradient_checkpointing: | 549 | if args.gradient_checkpointing: |
| 550 | unet.enable_gradient_checkpointing() | ||
| 560 | text_encoder.gradient_checkpointing_enable() | 551 | text_encoder.gradient_checkpointing_enable() |
| 561 | 552 | ||
| 562 | print(f"Adding text embeddings: {args.placeholder_token}") | 553 | print(f"Adding text embeddings: {args.placeholder_token}") |
| @@ -577,14 +568,25 @@ def main(): | |||
| 577 | 568 | ||
| 578 | # Initialise the newly added placeholder token with the embeddings of the initializer token | 569 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
| 579 | token_embeds = text_encoder.get_input_embeddings().weight.data | 570 | token_embeds = text_encoder.get_input_embeddings().weight.data |
| 580 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | ||
| 581 | 571 | ||
| 582 | if args.resume_checkpoint is not None: | 572 | if args.resume_from: |
| 583 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] | 573 | resumepath = Path(args.resume_from).joinpath("checkpoints") |
| 584 | else: | 574 | |
| 585 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 575 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): |
| 586 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 576 | embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin") |
| 587 | token_embeds[token_id] = embeddings | 577 | embedding_data = torch.load(embedding_file, map_location="cpu") |
| 578 | |||
| 579 | emb = next(iter(embedding_data.values())) | ||
| 580 | if len(emb.shape) == 1: | ||
| 581 | emb = emb.unsqueeze(0) | ||
| 582 | |||
| 583 | token_embeds[token_id] = emb | ||
| 584 | |||
| 585 | original_token_embeds = token_embeds.clone().to(accelerator.device) | ||
| 586 | |||
| 587 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | ||
| 588 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | ||
| 589 | token_embeds[token_id] = embeddings | ||
| 588 | 590 | ||
| 589 | index_fixed_tokens = torch.arange(len(tokenizer)) | 591 | index_fixed_tokens = torch.arange(len(tokenizer)) |
| 590 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | 592 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] |
| @@ -891,21 +893,16 @@ def main(): | |||
| 891 | 893 | ||
| 892 | accelerator.backward(loss) | 894 | accelerator.backward(loss) |
| 893 | 895 | ||
| 894 | # Keep the token embeddings fixed except the newly added | ||
| 895 | # embeddings for the concept, as we only want to optimize the concept embeddings | ||
| 896 | if accelerator.num_processes > 1: | ||
| 897 | token_embeds = text_encoder.module.get_input_embeddings().weight | ||
| 898 | else: | ||
| 899 | token_embeds = text_encoder.get_input_embeddings().weight | ||
| 900 | |||
| 901 | # Get the index for tokens that we want to freeze | ||
| 902 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] | ||
| 903 | |||
| 904 | optimizer.step() | 896 | optimizer.step() |
| 905 | if not accelerator.optimizer_step_was_skipped: | 897 | if not accelerator.optimizer_step_was_skipped: |
| 906 | lr_scheduler.step() | 898 | lr_scheduler.step() |
| 907 | optimizer.zero_grad(set_to_none=True) | 899 | optimizer.zero_grad(set_to_none=True) |
| 908 | 900 | ||
| 901 | # Let's make sure we don't update any embedding weights besides the newly added token | ||
| 902 | with torch.no_grad(): | ||
| 903 | text_encoder.get_input_embeddings( | ||
| 904 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] | ||
| 905 | |||
| 909 | loss = loss.detach().item() | 906 | loss = loss.detach().item() |
| 910 | train_loss += loss | 907 | train_loss += loss |
| 911 | 908 | ||
| @@ -916,19 +913,6 @@ def main(): | |||
| 916 | 913 | ||
| 917 | global_step += 1 | 914 | global_step += 1 |
| 918 | 915 | ||
| 919 | if global_step % args.sample_frequency == 0: | ||
| 920 | sample_checkpoint = True | ||
| 921 | |||
| 922 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | ||
| 923 | local_progress_bar.clear() | ||
| 924 | global_progress_bar.clear() | ||
| 925 | |||
| 926 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
| 927 | save_args(basepath, args, { | ||
| 928 | "global_step": global_step + global_step_offset, | ||
| 929 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
| 930 | }) | ||
| 931 | |||
| 932 | logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 916 | logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} |
| 933 | 917 | ||
| 934 | accelerator.log(logs, step=global_step) | 918 | accelerator.log(logs, step=global_step) |
| @@ -992,24 +976,30 @@ def main(): | |||
| 992 | local_progress_bar.clear() | 976 | local_progress_bar.clear() |
| 993 | global_progress_bar.clear() | 977 | global_progress_bar.clear() |
| 994 | 978 | ||
| 995 | if min_val_loss > val_loss: | 979 | if accelerator.is_main_process: |
| 996 | accelerator.print( | 980 | if min_val_loss > val_loss: |
| 997 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 981 | accelerator.print( |
| 998 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | 982 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
| 999 | min_val_loss = val_loss | 983 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") |
| 984 | min_val_loss = val_loss | ||
| 985 | |||
| 986 | if epoch % args.checkpoint_frequency == 0: | ||
| 987 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
| 988 | save_args(basepath, args, { | ||
| 989 | "global_step": global_step + global_step_offset | ||
| 990 | }) | ||
| 1000 | 991 | ||
| 1001 | if sample_checkpoint and accelerator.is_main_process: | 992 | if epoch % args.sample_frequency == 0: |
| 1002 | checkpointer.save_samples( | 993 | checkpointer.save_samples( |
| 1003 | global_step + global_step_offset, | 994 | global_step + global_step_offset, |
| 1004 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 995 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
| 1005 | 996 | ||
| 1006 | # Create the pipeline using using the trained modules and save it. | 997 | # Create the pipeline using using the trained modules and save it. |
| 1007 | if accelerator.is_main_process: | 998 | if accelerator.is_main_process: |
| 1008 | print("Finished! Saving final checkpoint and resume state.") | 999 | print("Finished! Saving final checkpoint and resume state.") |
| 1009 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 1000 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
| 1010 | save_args(basepath, args, { | 1001 | save_args(basepath, args, { |
| 1011 | "global_step": global_step + global_step_offset, | 1002 | "global_step": global_step + global_step_offset |
| 1012 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
| 1013 | }) | 1003 | }) |
| 1014 | accelerator.end_training() | 1004 | accelerator.end_training() |
| 1015 | 1005 | ||
| @@ -1018,8 +1008,7 @@ def main(): | |||
| 1018 | print("Interrupted, saving checkpoint and resume state...") | 1008 | print("Interrupted, saving checkpoint and resume state...") |
| 1019 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 1009 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
| 1020 | save_args(basepath, args, { | 1010 | save_args(basepath, args, { |
| 1021 | "global_step": global_step + global_step_offset, | 1011 | "global_step": global_step + global_step_offset |
| 1022 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
| 1023 | }) | 1012 | }) |
| 1024 | accelerator.end_training() | 1013 | accelerator.end_training() |
| 1025 | quit() | 1014 | quit() |
