path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2022-12-21 09:17:25 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-21 09:17:25 +0100
commit     68540b27849564994d921968a36faa9b997e626d (patch)
tree       8fbe834ab4c52f057cd114bbb0e786158f215acc /train_ti.py
parent     Fix training (diff)
Moved common training code into separate module
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  175
1 file changed, 55 insertions(+), 120 deletions(-)
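
The new import below pulls AverageMeter, CheckpointerBase, freeze_params and save_args from the shared training.util module. That module itself is not part of this diff, so the following is only a sketch of what it has to provide: freeze_params and save_args mirror the bodies deleted from train_ti.py further down, while AverageMeter's internals are assumed from how the training loop calls it (update(value, n) and .avg.item()).

# Sketch of training/util.py as implied by this commit (module not shown in the diff).
import json
from pathlib import Path


def freeze_params(params):
    for param in params:
        param.requires_grad = False


def save_args(basepath: Path, args, extra={}):
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


class AverageMeter:
    # Running average over a stream of (value, batch size) pairs; internals assumed.
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
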
diff --git a/train_ti.py b/train_ti.py
index 9616db6..198cf37 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -7,7 +7,6 @@ import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -17,7 +16,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -26,6 +24,7 @@ from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -138,7 +137,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0,
+        default=0.1,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -355,27 +354,7 @@ def parse_args():
     return args
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
@@ -394,21 +373,24 @@ class Checkpointer:
         sample_batch_size,
         seed
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.output_dir = output_dir
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -431,9 +413,7 @@ class Checkpointer:
         del learned_embeds
 
     @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a sample image
@@ -446,71 +426,10 @@ class Checkpointer:
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
         del text_encoder
         del pipeline
-        del generator
-        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -814,7 +733,14 @@ def main():
     # Only show the progress bar once on each machine.
 
     global_step = 0
-    min_val_loss = np.inf
+
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
@@ -835,9 +761,7 @@ def main():
     )
 
     if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+        checkpointer.save_samples(global_step_offset, args.sample_steps)
 
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -910,6 +834,8 @@ def main():
                 else:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
+                acc = (model_pred == latents).float().mean()
+
                 accelerator.backward(loss)
 
                 optimizer.step()
@@ -922,8 +848,8 @@ def main():
                     text_encoder.get_input_embeddings(
                     ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
 
-                loss = loss.detach().item()
-                train_loss += loss
+                avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -932,7 +858,13 @@ def main():
 
                 global_step += 1
 
-                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
 
                 accelerator.log(logs, step=global_step)
 
@@ -941,12 +873,9 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        train_loss /= len(train_dataloader)
-
         accelerator.wait_for_everyone()
 
         text_encoder.eval()
-        val_loss = 0.0
 
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
@@ -976,29 +905,37 @@ def main():
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-                loss = loss.detach().item()
-                val_loss += loss
+                acc = (model_pred == latents).float().mean()
+
+                avg_loss_val.update(loss.detach_(), bsz)
+                avg_acc_val.update(acc.detach_(), bsz)
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
+                logs = {
+                    "val/loss": avg_loss_val.avg.item(),
+                    "val/acc": avg_acc_val.avg.item(),
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
                 local_progress_bar.set_postfix(**logs)
 
-        val_loss /= len(val_dataloader)
-
-        accelerator.log({"val/loss": val_loss}, step=global_step)
+        accelerator.log({
+            "val/loss": avg_loss_val.avg.item(),
+            "val/acc": avg_acc_val.avg.item(),
+        }, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
 
         if accelerator.is_main_process:
-            if min_val_loss > val_loss:
+            if avg_acc_val.avg.item() > max_acc_val:
                 accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
-                min_val_loss = val_loss
+                max_acc_val = avg_acc_val.avg.item()
 
             if (epoch + 1) % args.checkpoint_frequency == 0:
                 checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -1007,9 +944,7 @@ def main():
                 })
 
             if (epoch + 1) % args.sample_frequency == 0:
-                checkpointer.save_samples(
-                    global_step + global_step_offset,
-                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+                checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
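
The Checkpointer changes above only make sense against the base class they now delegate to. CheckpointerBase comes from training/util.py, which is not shown in this diff, so the skeleton below is an assumption inferred from the super().__init__(...) arguments and the super().save_samples(pipeline, step, ...) call; presumably its save_samples() contains the grid-sampling loop that this commit deletes from train_ti.py.

# Assumed shape of training/util.py's CheckpointerBase (interface sketch only).
from pathlib import Path

import torch


class CheckpointerBase:
    def __init__(self, datamodule, output_dir, instance_identifier,
                 placeholder_token, placeholder_token_id, sample_image_size,
                 sample_batches, sample_batch_size, seed):
        # Store the fields that used to live directly on Checkpointer.
        self.datamodule = datamodule
        self.output_dir = output_dir
        self.instance_identifier = instance_identifier
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.sample_image_size = sample_image_size
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size
        self.seed = seed

    @torch.no_grad()
    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        # Presumably the loop removed from train_ti.py above: draw prompts from the
        # train/val dataloaders, render sample batches with the supplied pipeline,
        # and write one image grid per pool to <output_dir>/samples/<pool>/step_<step>.jpg.
        samples_path = Path(self.output_dir).joinpath("samples")
        ...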