diff options
author | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
commit | 68540b27849564994d921968a36faa9b997e626d (patch) | |
tree | 8fbe834ab4c52f057cd114bbb0e786158f215acc /train_ti.py | |
parent | Fix training (diff) | |
download | textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.gz textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.bz2 textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.zip |
Moved common training code into separate module
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 175 |
1 files changed, 55 insertions, 120 deletions
diff --git a/train_ti.py b/train_ti.py index 9616db6..198cf37 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -7,7 +7,6 @@ import logging | |||
7 | import json | 7 | import json |
8 | from pathlib import Path | 8 | from pathlib import Path |
9 | 9 | ||
10 | import numpy as np | ||
11 | import torch | 10 | import torch |
12 | import torch.nn.functional as F | 11 | import torch.nn.functional as F |
13 | import torch.utils.checkpoint | 12 | import torch.utils.checkpoint |
@@ -17,7 +16,6 @@ from accelerate.logging import get_logger | |||
17 | from accelerate.utils import LoggerType, set_seed | 16 | from accelerate.utils import LoggerType, set_seed |
18 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel | 17 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 18 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
20 | from PIL import Image | ||
21 | from tqdm.auto import tqdm | 19 | from tqdm.auto import tqdm |
22 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
23 | from slugify import slugify | 21 | from slugify import slugify |
@@ -26,6 +24,7 @@ from common import load_text_embeddings, load_text_embedding | |||
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import CSVDataModule, CSVDataItem |
28 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args | ||
29 | from models.clip.prompt import PromptProcessor | 28 | from models.clip.prompt import PromptProcessor |
30 | 29 | ||
31 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
@@ -138,7 +137,7 @@ def parse_args(): | |||
138 | parser.add_argument( | 137 | parser.add_argument( |
139 | "--tag_dropout", | 138 | "--tag_dropout", |
140 | type=float, | 139 | type=float, |
141 | default=0, | 140 | default=0.1, |
142 | help="Tag dropout probability.", | 141 | help="Tag dropout probability.", |
143 | ) | 142 | ) |
144 | parser.add_argument( | 143 | parser.add_argument( |
@@ -355,27 +354,7 @@ def parse_args(): | |||
355 | return args | 354 | return args |
356 | 355 | ||
357 | 356 | ||
358 | def freeze_params(params): | 357 | class Checkpointer(CheckpointerBase): |
359 | for param in params: | ||
360 | param.requires_grad = False | ||
361 | |||
362 | |||
363 | def save_args(basepath: Path, args, extra={}): | ||
364 | info = {"args": vars(args)} | ||
365 | info["args"].update(extra) | ||
366 | with open(basepath.joinpath("args.json"), "w") as f: | ||
367 | json.dump(info, f, indent=4) | ||
368 | |||
369 | |||
370 | def make_grid(images, rows, cols): | ||
371 | w, h = images[0].size | ||
372 | grid = Image.new('RGB', size=(cols*w, rows*h)) | ||
373 | for i, image in enumerate(images): | ||
374 | grid.paste(image, box=(i % cols*w, i//cols*h)) | ||
375 | return grid | ||
376 | |||
377 | |||
378 | class Checkpointer: | ||
379 | def __init__( | 358 | def __init__( |
380 | self, | 359 | self, |
381 | datamodule, | 360 | datamodule, |
@@ -394,21 +373,24 @@ class Checkpointer: | |||
394 | sample_batch_size, | 373 | sample_batch_size, |
395 | seed | 374 | seed |
396 | ): | 375 | ): |
397 | self.datamodule = datamodule | 376 | super().__init__( |
377 | datamodule=datamodule, | ||
378 | output_dir=output_dir, | ||
379 | instance_identifier=instance_identifier, | ||
380 | placeholder_token=placeholder_token, | ||
381 | placeholder_token_id=placeholder_token_id, | ||
382 | sample_image_size=sample_image_size, | ||
383 | seed=seed or torch.random.seed(), | ||
384 | sample_batches=sample_batches, | ||
385 | sample_batch_size=sample_batch_size | ||
386 | ) | ||
387 | |||
398 | self.accelerator = accelerator | 388 | self.accelerator = accelerator |
399 | self.vae = vae | 389 | self.vae = vae |
400 | self.unet = unet | 390 | self.unet = unet |
401 | self.tokenizer = tokenizer | 391 | self.tokenizer = tokenizer |
402 | self.text_encoder = text_encoder | 392 | self.text_encoder = text_encoder |
403 | self.scheduler = scheduler | 393 | self.scheduler = scheduler |
404 | self.instance_identifier = instance_identifier | ||
405 | self.placeholder_token = placeholder_token | ||
406 | self.placeholder_token_id = placeholder_token_id | ||
407 | self.output_dir = output_dir | ||
408 | self.sample_image_size = sample_image_size | ||
409 | self.seed = seed or torch.random.seed() | ||
410 | self.sample_batches = sample_batches | ||
411 | self.sample_batch_size = sample_batch_size | ||
412 | 394 | ||
413 | @torch.no_grad() | 395 | @torch.no_grad() |
414 | def checkpoint(self, step, postfix): | 396 | def checkpoint(self, step, postfix): |
@@ -431,9 +413,7 @@ class Checkpointer: | |||
431 | del learned_embeds | 413 | del learned_embeds |
432 | 414 | ||
433 | @torch.no_grad() | 415 | @torch.no_grad() |
434 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): | 416 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
435 | samples_path = Path(self.output_dir).joinpath("samples") | ||
436 | |||
437 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 417 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
438 | 418 | ||
439 | # Save a sample image | 419 | # Save a sample image |
@@ -446,71 +426,10 @@ class Checkpointer: | |||
446 | ).to(self.accelerator.device) | 426 | ).to(self.accelerator.device) |
447 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 427 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
448 | 428 | ||
449 | train_data = self.datamodule.train_dataloader() | 429 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) |
450 | val_data = self.datamodule.val_dataloader() | ||
451 | |||
452 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | ||
453 | stable_latents = torch.randn( | ||
454 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | ||
455 | device=pipeline.device, | ||
456 | generator=generator, | ||
457 | ) | ||
458 | |||
459 | with torch.autocast("cuda"), torch.inference_mode(): | ||
460 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | ||
461 | all_samples = [] | ||
462 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | ||
463 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
464 | |||
465 | data_enum = enumerate(data) | ||
466 | |||
467 | batches = [ | ||
468 | batch | ||
469 | for j, batch in data_enum | ||
470 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | ||
471 | ] | ||
472 | prompts = [ | ||
473 | prompt.format(identifier=self.instance_identifier) | ||
474 | for batch in batches | ||
475 | for prompt in batch["prompts"] | ||
476 | ] | ||
477 | nprompts = [ | ||
478 | prompt | ||
479 | for batch in batches | ||
480 | for prompt in batch["nprompts"] | ||
481 | ] | ||
482 | |||
483 | for i in range(self.sample_batches): | ||
484 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
485 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
486 | |||
487 | samples = pipeline( | ||
488 | prompt=prompt, | ||
489 | negative_prompt=nprompt, | ||
490 | height=self.sample_image_size, | ||
491 | width=self.sample_image_size, | ||
492 | image=latents[:len(prompt)] if latents is not None else None, | ||
493 | generator=generator if latents is not None else None, | ||
494 | guidance_scale=guidance_scale, | ||
495 | eta=eta, | ||
496 | num_inference_steps=num_inference_steps, | ||
497 | output_type='pil' | ||
498 | ).images | ||
499 | |||
500 | all_samples += samples | ||
501 | |||
502 | del samples | ||
503 | |||
504 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | ||
505 | image_grid.save(file_path, quality=85) | ||
506 | |||
507 | del all_samples | ||
508 | del image_grid | ||
509 | 430 | ||
510 | del text_encoder | 431 | del text_encoder |
511 | del pipeline | 432 | del pipeline |
512 | del generator | ||
513 | del stable_latents | ||
514 | 433 | ||
515 | if torch.cuda.is_available(): | 434 | if torch.cuda.is_available(): |
516 | torch.cuda.empty_cache() | 435 | torch.cuda.empty_cache() |
@@ -814,7 +733,14 @@ def main(): | |||
814 | # Only show the progress bar once on each machine. | 733 | # Only show the progress bar once on each machine. |
815 | 734 | ||
816 | global_step = 0 | 735 | global_step = 0 |
817 | min_val_loss = np.inf | 736 | |
737 | avg_loss = AverageMeter() | ||
738 | avg_acc = AverageMeter() | ||
739 | |||
740 | avg_loss_val = AverageMeter() | ||
741 | avg_acc_val = AverageMeter() | ||
742 | |||
743 | max_acc_val = 0.0 | ||
818 | 744 | ||
819 | checkpointer = Checkpointer( | 745 | checkpointer = Checkpointer( |
820 | datamodule=datamodule, | 746 | datamodule=datamodule, |
@@ -835,9 +761,7 @@ def main(): | |||
835 | ) | 761 | ) |
836 | 762 | ||
837 | if accelerator.is_main_process: | 763 | if accelerator.is_main_process: |
838 | checkpointer.save_samples( | 764 | checkpointer.save_samples(global_step_offset, args.sample_steps) |
839 | 0, | ||
840 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
841 | 765 | ||
842 | local_progress_bar = tqdm( | 766 | local_progress_bar = tqdm( |
843 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 767 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
@@ -910,6 +834,8 @@ def main(): | |||
910 | else: | 834 | else: |
911 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 835 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
912 | 836 | ||
837 | acc = (model_pred == latents).float().mean() | ||
838 | |||
913 | accelerator.backward(loss) | 839 | accelerator.backward(loss) |
914 | 840 | ||
915 | optimizer.step() | 841 | optimizer.step() |
@@ -922,8 +848,8 @@ def main(): | |||
922 | text_encoder.get_input_embeddings( | 848 | text_encoder.get_input_embeddings( |
923 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] | 849 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] |
924 | 850 | ||
925 | loss = loss.detach().item() | 851 | avg_loss.update(loss.detach_(), bsz) |
926 | train_loss += loss | 852 | avg_acc.update(acc.detach_(), bsz) |
927 | 853 | ||
928 | # Checks if the accelerator has performed an optimization step behind the scenes | 854 | # Checks if the accelerator has performed an optimization step behind the scenes |
929 | if accelerator.sync_gradients: | 855 | if accelerator.sync_gradients: |
@@ -932,7 +858,13 @@ def main(): | |||
932 | 858 | ||
933 | global_step += 1 | 859 | global_step += 1 |
934 | 860 | ||
935 | logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 861 | logs = { |
862 | "train/loss": avg_loss.avg.item(), | ||
863 | "train/acc": avg_acc.avg.item(), | ||
864 | "train/cur_loss": loss.item(), | ||
865 | "train/cur_acc": acc.item(), | ||
866 | "lr": lr_scheduler.get_last_lr()[0], | ||
867 | } | ||
936 | 868 | ||
937 | accelerator.log(logs, step=global_step) | 869 | accelerator.log(logs, step=global_step) |
938 | 870 | ||
@@ -941,12 +873,9 @@ def main(): | |||
941 | if global_step >= args.max_train_steps: | 873 | if global_step >= args.max_train_steps: |
942 | break | 874 | break |
943 | 875 | ||
944 | train_loss /= len(train_dataloader) | ||
945 | |||
946 | accelerator.wait_for_everyone() | 876 | accelerator.wait_for_everyone() |
947 | 877 | ||
948 | text_encoder.eval() | 878 | text_encoder.eval() |
949 | val_loss = 0.0 | ||
950 | 879 | ||
951 | with torch.inference_mode(): | 880 | with torch.inference_mode(): |
952 | for step, batch in enumerate(val_dataloader): | 881 | for step, batch in enumerate(val_dataloader): |
@@ -976,29 +905,37 @@ def main(): | |||
976 | 905 | ||
977 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 906 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
978 | 907 | ||
979 | loss = loss.detach().item() | 908 | acc = (model_pred == latents).float().mean() |
980 | val_loss += loss | 909 | |
910 | avg_loss_val.update(loss.detach_(), bsz) | ||
911 | avg_acc_val.update(acc.detach_(), bsz) | ||
981 | 912 | ||
982 | if accelerator.sync_gradients: | 913 | if accelerator.sync_gradients: |
983 | local_progress_bar.update(1) | 914 | local_progress_bar.update(1) |
984 | global_progress_bar.update(1) | 915 | global_progress_bar.update(1) |
985 | 916 | ||
986 | logs = {"val/loss": loss} | 917 | logs = { |
918 | "val/loss": avg_loss_val.avg.item(), | ||
919 | "val/acc": avg_acc_val.avg.item(), | ||
920 | "val/cur_loss": loss.item(), | ||
921 | "val/cur_acc": acc.item(), | ||
922 | } | ||
987 | local_progress_bar.set_postfix(**logs) | 923 | local_progress_bar.set_postfix(**logs) |
988 | 924 | ||
989 | val_loss /= len(val_dataloader) | 925 | accelerator.log({ |
990 | 926 | "val/loss": avg_loss_val.avg.item(), | |
991 | accelerator.log({"val/loss": val_loss}, step=global_step) | 927 | "val/acc": avg_acc_val.avg.item(), |
928 | }, step=global_step) | ||
992 | 929 | ||
993 | local_progress_bar.clear() | 930 | local_progress_bar.clear() |
994 | global_progress_bar.clear() | 931 | global_progress_bar.clear() |
995 | 932 | ||
996 | if accelerator.is_main_process: | 933 | if accelerator.is_main_process: |
997 | if min_val_loss > val_loss: | 934 | if avg_acc_val.avg.item() > max_acc_val: |
998 | accelerator.print( | 935 | accelerator.print( |
999 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 936 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
1000 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | 937 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") |
1001 | min_val_loss = val_loss | 938 | max_acc_val = avg_acc_val.avg.item() |
1002 | 939 | ||
1003 | if (epoch + 1) % args.checkpoint_frequency == 0: | 940 | if (epoch + 1) % args.checkpoint_frequency == 0: |
1004 | checkpointer.checkpoint(global_step + global_step_offset, "training") | 941 | checkpointer.checkpoint(global_step + global_step_offset, "training") |
@@ -1007,9 +944,7 @@ def main(): | |||
1007 | }) | 944 | }) |
1008 | 945 | ||
1009 | if (epoch + 1) % args.sample_frequency == 0: | 946 | if (epoch + 1) % args.sample_frequency == 0: |
1010 | checkpointer.save_samples( | 947 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) |
1011 | global_step + global_step_offset, | ||
1012 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
1013 | 948 | ||
1014 | # Create the pipeline using using the trained modules and save it. | 949 | # Create the pipeline using using the trained modules and save it. |
1015 | if accelerator.is_main_process: | 950 | if accelerator.is_main_process: |