| | | |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
| commit | 68540b27849564994d921968a36faa9b997e626d (patch) | |
| tree | 8fbe834ab4c52f057cd114bbb0e786158f215acc /train_ti.py | |
| parent | Fix training (diff) | |
Moved common training code into separate module
Diffstat (limited to 'train_ti.py')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | train_ti.py | 175 |
1 file changed, 55 insertions, 120 deletions
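For reference, the helpers this commit strips out of train_ti.py now come from the new training.util module (imported as AverageMeter, CheckpointerBase, freeze_params and save_args in the diff below). Here is a minimal sketch of what that module plausibly provides: freeze_params and save_args are the functions removed verbatim from train_ti.py, while the AverageMeter shape is only inferred from its update(value, n) / .avg usage in the new training loop, so treat it as an assumption rather than the repository's actual code.

```python
# training/util.py -- sketch, not the repository's actual file.
# freeze_params/save_args are the helpers deleted from train_ti.py by this commit;
# AverageMeter is assumed from its usage (update(value, n), .avg) in the new loop.
import json
from pathlib import Path


def freeze_params(params):
    # Disable gradients for a group of parameters (e.g. VAE/UNet during TI training).
    for param in params:
        param.requires_grad = False


def save_args(basepath: Path, args, extra={}):
    # Dump the argparse namespace (plus any extra keys) to args.json in the run dir.
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


class AverageMeter:
    # Running average over a stream of (value, batch size) pairs; the training and
    # validation loops read .avg for logging.
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count
```

CheckpointerBase, which Checkpointer now subclasses, is sketched after the diff.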
```diff
diff --git a/train_ti.py b/train_ti.py
index 9616db6..198cf37 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -7,7 +7,6 @@ import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -17,7 +16,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -26,6 +24,7 @@ from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -138,7 +137,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0,
+        default=0.1,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -355,27 +354,7 @@ def parse_args():
     return args
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
@@ -394,21 +373,24 @@ class Checkpointer:
         sample_batch_size,
         seed
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.output_dir = output_dir
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -431,9 +413,7 @@ class Checkpointer:
         del learned_embeds
 
     @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a sample image
@@ -446,71 +426,10 @@ class Checkpointer:
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
         del text_encoder
         del pipeline
-        del generator
-        del stable_latents
 
         if torch.cuda.is_available():
            torch.cuda.empty_cache()
@@ -814,7 +733,14 @@ def main():
     # Only show the progress bar once on each machine.
 
     global_step = 0
-    min_val_loss = np.inf
+
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
@@ -835,9 +761,7 @@ def main():
     )
 
     if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+        checkpointer.save_samples(global_step_offset, args.sample_steps)
 
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -910,6 +834,8 @@ def main():
                 else:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
+                acc = (model_pred == latents).float().mean()
+
                 accelerator.backward(loss)
 
                 optimizer.step()
@@ -922,8 +848,8 @@
                     text_encoder.get_input_embeddings(
                     ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
 
-                loss = loss.detach().item()
-                train_loss += loss
+                avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -932,7 +858,13 @@
 
                 global_step += 1
 
-                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
 
                 accelerator.log(logs, step=global_step)
 
@@ -941,12 +873,9 @@
             if global_step >= args.max_train_steps:
                 break
 
-        train_loss /= len(train_dataloader)
-
         accelerator.wait_for_everyone()
 
         text_encoder.eval()
-        val_loss = 0.0
 
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
@@ -976,29 +905,37 @@
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-                loss = loss.detach().item()
-                val_loss += loss
+                acc = (model_pred == latents).float().mean()
+
+                avg_loss_val.update(loss.detach_(), bsz)
+                avg_acc_val.update(acc.detach_(), bsz)
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                    logs = {"val/loss": loss}
+                    logs = {
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
                     local_progress_bar.set_postfix(**logs)
 
-            val_loss /= len(val_dataloader)
-
-            accelerator.log({"val/loss": val_loss}, step=global_step)
+            accelerator.log({
+                "val/loss": avg_loss_val.avg.item(),
+                "val/acc": avg_acc_val.avg.item(),
+            }, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
 
         if accelerator.is_main_process:
-            if min_val_loss > val_loss:
+            if avg_acc_val.avg.item() > max_acc_val:
                 accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
-                min_val_loss = val_loss
+                max_acc_val = avg_acc_val.avg.item()
 
             if (epoch + 1) % args.checkpoint_frequency == 0:
                 checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -1007,9 +944,7 @@
                })
 
         if (epoch + 1) % args.sample_frequency == 0:
-            checkpointer.save_samples(
-                global_step + global_step_offset,
-                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+            checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
```
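To round out the picture, below is an assumed shape for CheckpointerBase (also exported by training.util, per the import above), reconstructed from the keyword arguments Checkpointer now forwards via super().__init__() and from the sample-grid code this commit deletes from save_samples. Using self.sample_image_size for the stable latents is an inference: the old code derived them from the height/width parameters that the new signature drops. A sketch under those assumptions, not the module's actual implementation:

```python
# training/util.py (continued) -- assumed shape of CheckpointerBase, reconstructed
# from the kwargs forwarded via super().__init__() and from the sample-grid code
# removed from train_ti.py above; make_grid presumably moved here as well.
from pathlib import Path

import torch
from PIL import Image


def make_grid(images, rows, cols):
    # Tile PIL images into a rows x cols contact sheet (moved out of train_ti.py).
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


class CheckpointerBase:
    def __init__(self, datamodule, output_dir, instance_identifier,
                 placeholder_token, placeholder_token_id, sample_image_size,
                 sample_batches, sample_batch_size, seed):
        self.datamodule = datamodule
        self.output_dir = output_dir
        self.instance_identifier = instance_identifier
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.sample_image_size = sample_image_size
        self.seed = seed
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    @torch.no_grad()
    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        # Same logic the old Checkpointer.save_samples ran inline: one image grid per
        # pool ("stable" = fixed latents, "val", "train") at each sample step.
        samples_path = Path(self.output_dir).joinpath("samples")

        train_data = self.datamodule.train_dataloader()
        val_data = self.datamodule.val_dataloader()

        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
        stable_latents = torch.randn(
            (self.sample_batch_size, pipeline.unet.in_channels,
             self.sample_image_size // 8, self.sample_image_size // 8),
            device=pipeline.device,
            generator=generator,
        )

        with torch.autocast("cuda"), torch.inference_mode():
            for pool, data, latents in [("stable", val_data, stable_latents),
                                        ("val", val_data, None),
                                        ("train", train_data, None)]:
                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                file_path.parent.mkdir(parents=True, exist_ok=True)

                batches = [
                    batch for j, batch in enumerate(data)
                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
                ]
                prompts = [
                    prompt.format(identifier=self.instance_identifier)
                    for batch in batches for prompt in batch["prompts"]
                ]
                nprompts = [prompt for batch in batches for prompt in batch["nprompts"]]

                all_samples = []
                for i in range(self.sample_batches):
                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]

                    all_samples += pipeline(
                        prompt=prompt,
                        negative_prompt=nprompt,
                        height=self.sample_image_size,
                        width=self.sample_image_size,
                        image=latents[:len(prompt)] if latents is not None else None,
                        generator=generator if latents is not None else None,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                        output_type='pil',
                    ).images

                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
                image_grid.save(file_path, quality=85)
```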
