summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-21 09:17:25 +0100
committerVolpeon <git@volpeon.ink>2022-12-21 09:17:25 +0100
commit68540b27849564994d921968a36faa9b997e626d (patch)
tree8fbe834ab4c52f057cd114bbb0e786158f215acc
parentFix training (diff)
downloadtextual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.gz
textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.bz2
textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.zip
Moved common training code into separate module
-rw-r--r--train_dreambooth.py126
-rw-r--r--train_ti.py175
-rw-r--r--training/optimization.py2
-rw-r--r--training/util.py131
4 files changed, 203 insertions, 231 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0f8fece..9749c62 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed
16from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel 16from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
17from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup 17from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
18from diffusers.training_utils import EMAModel 18from diffusers.training_utils import EMAModel
19from PIL import Image
20from tqdm.auto import tqdm 19from tqdm.auto import tqdm
21from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel, CLIPTokenizer
22from slugify import slugify 21from slugify import slugify
@@ -25,6 +24,7 @@ from common import load_text_embeddings
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from data.csv import CSVDataModule 25from data.csv import CSVDataModule
27from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
28from models.clip.prompt import PromptProcessor 28from models.clip.prompt import PromptProcessor
29 29
30logger = get_logger(__name__) 30logger = get_logger(__name__)
@@ -385,41 +385,7 @@ def parse_args():
385 return args 385 return args
386 386
387 387
388def save_args(basepath: Path, args, extra={}): 388class Checkpointer(CheckpointerBase):
389 info = {"args": vars(args)}
390 info["args"].update(extra)
391 with open(basepath.joinpath("args.json"), "w") as f:
392 json.dump(info, f, indent=4)
393
394
395def freeze_params(params):
396 for param in params:
397 param.requires_grad = False
398
399
400def make_grid(images, rows, cols):
401 w, h = images[0].size
402 grid = Image.new('RGB', size=(cols*w, rows*h))
403 for i, image in enumerate(images):
404 grid.paste(image, box=(i % cols*w, i//cols*h))
405 return grid
406
407
408class AverageMeter:
409 def __init__(self, name=None):
410 self.name = name
411 self.reset()
412
413 def reset(self):
414 self.sum = self.count = self.avg = 0
415
416 def update(self, val, n=1):
417 self.sum += val * n
418 self.count += n
419 self.avg = self.sum / self.count
420
421
422class Checkpointer:
423 def __init__( 389 def __init__(
424 self, 390 self,
425 datamodule, 391 datamodule,
@@ -437,9 +403,20 @@ class Checkpointer:
437 sample_image_size, 403 sample_image_size,
438 sample_batches, 404 sample_batches,
439 sample_batch_size, 405 sample_batch_size,
440 seed 406 seed,
441 ): 407 ):
442 self.datamodule = datamodule 408 super().__init__(
409 datamodule=datamodule,
410 output_dir=output_dir,
411 instance_identifier=instance_identifier,
412 placeholder_token=placeholder_token,
413 placeholder_token_id=placeholder_token_id,
414 sample_image_size=sample_image_size,
415 seed=seed or torch.random.seed(),
416 sample_batches=sample_batches,
417 sample_batch_size=sample_batch_size
418 )
419
443 self.accelerator = accelerator 420 self.accelerator = accelerator
444 self.vae = vae 421 self.vae = vae
445 self.unet = unet 422 self.unet = unet
@@ -447,14 +424,6 @@ class Checkpointer:
447 self.tokenizer = tokenizer 424 self.tokenizer = tokenizer
448 self.text_encoder = text_encoder 425 self.text_encoder = text_encoder
449 self.scheduler = scheduler 426 self.scheduler = scheduler
450 self.output_dir = output_dir
451 self.instance_identifier = instance_identifier
452 self.placeholder_token = placeholder_token
453 self.placeholder_token_id = placeholder_token_id
454 self.sample_image_size = sample_image_size
455 self.seed = seed or torch.random.seed()
456 self.sample_batches = sample_batches
457 self.sample_batch_size = sample_batch_size
458 427
459 @torch.no_grad() 428 @torch.no_grad()
460 def save_model(self): 429 def save_model(self):
@@ -481,8 +450,6 @@ class Checkpointer:
481 450
482 @torch.no_grad() 451 @torch.no_grad()
483 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 452 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
484 samples_path = Path(self.output_dir).joinpath("samples")
485
486 unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) 453 unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
487 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 454 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
488 455
@@ -495,72 +462,11 @@ class Checkpointer:
495 ).to(self.accelerator.device) 462 ).to(self.accelerator.device)
496 pipeline.set_progress_bar_config(dynamic_ncols=True) 463 pipeline.set_progress_bar_config(dynamic_ncols=True)
497 464
498 train_data = self.datamodule.train_dataloader() 465 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
499 val_data = self.datamodule.val_dataloader()
500
501 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
502 stable_latents = torch.randn(
503 (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
504 device=pipeline.device,
505 generator=generator,
506 )
507
508 with torch.autocast("cuda"), torch.inference_mode():
509 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
510 all_samples = []
511 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
512 file_path.parent.mkdir(parents=True, exist_ok=True)
513
514 data_enum = enumerate(data)
515
516 batches = [
517 batch
518 for j, batch in data_enum
519 if j * data.batch_size < self.sample_batch_size * self.sample_batches
520 ]
521 prompts = [
522 prompt.format(identifier=self.instance_identifier)
523 for batch in batches
524 for prompt in batch["prompts"]
525 ]
526 nprompts = [
527 prompt
528 for batch in batches
529 for prompt in batch["nprompts"]
530 ]
531
532 for i in range(self.sample_batches):
533 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
534 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
535
536 samples = pipeline(
537 prompt=prompt,
538 negative_prompt=nprompt,
539 height=self.sample_image_size,
540 width=self.sample_image_size,
541 image=latents[:len(prompt)] if latents is not None else None,
542 generator=generator if latents is not None else None,
543 guidance_scale=guidance_scale,
544 eta=eta,
545 num_inference_steps=num_inference_steps,
546 output_type='pil'
547 ).images
548
549 all_samples += samples
550
551 del samples
552
553 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
554 image_grid.save(file_path, quality=85)
555
556 del all_samples
557 del image_grid
558 466
559 del unet 467 del unet
560 del text_encoder 468 del text_encoder
561 del pipeline 469 del pipeline
562 del generator
563 del stable_latents
564 470
565 if torch.cuda.is_available(): 471 if torch.cuda.is_available():
566 torch.cuda.empty_cache() 472 torch.cuda.empty_cache()
diff --git a/train_ti.py b/train_ti.py
index 9616db6..198cf37 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -7,7 +7,6 @@ import logging
7import json 7import json
8from pathlib import Path 8from pathlib import Path
9 9
10import numpy as np
11import torch 10import torch
12import torch.nn.functional as F 11import torch.nn.functional as F
13import torch.utils.checkpoint 12import torch.utils.checkpoint
@@ -17,7 +16,6 @@ from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 16from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel 17from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup 18from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
20from PIL import Image
21from tqdm.auto import tqdm 19from tqdm.auto import tqdm
22from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 21from slugify import slugify
@@ -26,6 +24,7 @@ from common import load_text_embeddings, load_text_embedding
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from data.csv import CSVDataModule, CSVDataItem 25from data.csv import CSVDataModule, CSVDataItem
28from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
29from models.clip.prompt import PromptProcessor 28from models.clip.prompt import PromptProcessor
30 29
31logger = get_logger(__name__) 30logger = get_logger(__name__)
@@ -138,7 +137,7 @@ def parse_args():
138 parser.add_argument( 137 parser.add_argument(
139 "--tag_dropout", 138 "--tag_dropout",
140 type=float, 139 type=float,
141 default=0, 140 default=0.1,
142 help="Tag dropout probability.", 141 help="Tag dropout probability.",
143 ) 142 )
144 parser.add_argument( 143 parser.add_argument(
@@ -355,27 +354,7 @@ def parse_args():
355 return args 354 return args
356 355
357 356
358def freeze_params(params): 357class Checkpointer(CheckpointerBase):
359 for param in params:
360 param.requires_grad = False
361
362
363def save_args(basepath: Path, args, extra={}):
364 info = {"args": vars(args)}
365 info["args"].update(extra)
366 with open(basepath.joinpath("args.json"), "w") as f:
367 json.dump(info, f, indent=4)
368
369
370def make_grid(images, rows, cols):
371 w, h = images[0].size
372 grid = Image.new('RGB', size=(cols*w, rows*h))
373 for i, image in enumerate(images):
374 grid.paste(image, box=(i % cols*w, i//cols*h))
375 return grid
376
377
378class Checkpointer:
379 def __init__( 358 def __init__(
380 self, 359 self,
381 datamodule, 360 datamodule,
@@ -394,21 +373,24 @@ class Checkpointer:
394 sample_batch_size, 373 sample_batch_size,
395 seed 374 seed
396 ): 375 ):
397 self.datamodule = datamodule 376 super().__init__(
377 datamodule=datamodule,
378 output_dir=output_dir,
379 instance_identifier=instance_identifier,
380 placeholder_token=placeholder_token,
381 placeholder_token_id=placeholder_token_id,
382 sample_image_size=sample_image_size,
383 seed=seed or torch.random.seed(),
384 sample_batches=sample_batches,
385 sample_batch_size=sample_batch_size
386 )
387
398 self.accelerator = accelerator 388 self.accelerator = accelerator
399 self.vae = vae 389 self.vae = vae
400 self.unet = unet 390 self.unet = unet
401 self.tokenizer = tokenizer 391 self.tokenizer = tokenizer
402 self.text_encoder = text_encoder 392 self.text_encoder = text_encoder
403 self.scheduler = scheduler 393 self.scheduler = scheduler
404 self.instance_identifier = instance_identifier
405 self.placeholder_token = placeholder_token
406 self.placeholder_token_id = placeholder_token_id
407 self.output_dir = output_dir
408 self.sample_image_size = sample_image_size
409 self.seed = seed or torch.random.seed()
410 self.sample_batches = sample_batches
411 self.sample_batch_size = sample_batch_size
412 394
413 @torch.no_grad() 395 @torch.no_grad()
414 def checkpoint(self, step, postfix): 396 def checkpoint(self, step, postfix):
@@ -431,9 +413,7 @@ class Checkpointer:
431 del learned_embeds 413 del learned_embeds
432 414
433 @torch.no_grad() 415 @torch.no_grad()
434 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): 416 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
435 samples_path = Path(self.output_dir).joinpath("samples")
436
437 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 417 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
438 418
439 # Save a sample image 419 # Save a sample image
@@ -446,71 +426,10 @@ class Checkpointer:
446 ).to(self.accelerator.device) 426 ).to(self.accelerator.device)
447 pipeline.set_progress_bar_config(dynamic_ncols=True) 427 pipeline.set_progress_bar_config(dynamic_ncols=True)
448 428
449 train_data = self.datamodule.train_dataloader() 429 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
450 val_data = self.datamodule.val_dataloader()
451
452 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
453 stable_latents = torch.randn(
454 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
455 device=pipeline.device,
456 generator=generator,
457 )
458
459 with torch.autocast("cuda"), torch.inference_mode():
460 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
461 all_samples = []
462 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
463 file_path.parent.mkdir(parents=True, exist_ok=True)
464
465 data_enum = enumerate(data)
466
467 batches = [
468 batch
469 for j, batch in data_enum
470 if j * data.batch_size < self.sample_batch_size * self.sample_batches
471 ]
472 prompts = [
473 prompt.format(identifier=self.instance_identifier)
474 for batch in batches
475 for prompt in batch["prompts"]
476 ]
477 nprompts = [
478 prompt
479 for batch in batches
480 for prompt in batch["nprompts"]
481 ]
482
483 for i in range(self.sample_batches):
484 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
485 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
486
487 samples = pipeline(
488 prompt=prompt,
489 negative_prompt=nprompt,
490 height=self.sample_image_size,
491 width=self.sample_image_size,
492 image=latents[:len(prompt)] if latents is not None else None,
493 generator=generator if latents is not None else None,
494 guidance_scale=guidance_scale,
495 eta=eta,
496 num_inference_steps=num_inference_steps,
497 output_type='pil'
498 ).images
499
500 all_samples += samples
501
502 del samples
503
504 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
505 image_grid.save(file_path, quality=85)
506
507 del all_samples
508 del image_grid
509 430
510 del text_encoder 431 del text_encoder
511 del pipeline 432 del pipeline
512 del generator
513 del stable_latents
514 433
515 if torch.cuda.is_available(): 434 if torch.cuda.is_available():
516 torch.cuda.empty_cache() 435 torch.cuda.empty_cache()
@@ -814,7 +733,14 @@ def main():
814 # Only show the progress bar once on each machine. 733 # Only show the progress bar once on each machine.
815 734
816 global_step = 0 735 global_step = 0
817 min_val_loss = np.inf 736
737 avg_loss = AverageMeter()
738 avg_acc = AverageMeter()
739
740 avg_loss_val = AverageMeter()
741 avg_acc_val = AverageMeter()
742
743 max_acc_val = 0.0
818 744
819 checkpointer = Checkpointer( 745 checkpointer = Checkpointer(
820 datamodule=datamodule, 746 datamodule=datamodule,
@@ -835,9 +761,7 @@ def main():
835 ) 761 )
836 762
837 if accelerator.is_main_process: 763 if accelerator.is_main_process:
838 checkpointer.save_samples( 764 checkpointer.save_samples(global_step_offset, args.sample_steps)
839 0,
840 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
841 765
842 local_progress_bar = tqdm( 766 local_progress_bar = tqdm(
843 range(num_update_steps_per_epoch + num_val_steps_per_epoch), 767 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -910,6 +834,8 @@ def main():
910 else: 834 else:
911 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 835 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
912 836
837 acc = (model_pred == latents).float().mean()
838
913 accelerator.backward(loss) 839 accelerator.backward(loss)
914 840
915 optimizer.step() 841 optimizer.step()
@@ -922,8 +848,8 @@ def main():
922 text_encoder.get_input_embeddings( 848 text_encoder.get_input_embeddings(
923 ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] 849 ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
924 850
925 loss = loss.detach().item() 851 avg_loss.update(loss.detach_(), bsz)
926 train_loss += loss 852 avg_acc.update(acc.detach_(), bsz)
927 853
928 # Checks if the accelerator has performed an optimization step behind the scenes 854 # Checks if the accelerator has performed an optimization step behind the scenes
929 if accelerator.sync_gradients: 855 if accelerator.sync_gradients:
@@ -932,7 +858,13 @@ def main():
932 858
933 global_step += 1 859 global_step += 1
934 860
935 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} 861 logs = {
862 "train/loss": avg_loss.avg.item(),
863 "train/acc": avg_acc.avg.item(),
864 "train/cur_loss": loss.item(),
865 "train/cur_acc": acc.item(),
866 "lr": lr_scheduler.get_last_lr()[0],
867 }
936 868
937 accelerator.log(logs, step=global_step) 869 accelerator.log(logs, step=global_step)
938 870
@@ -941,12 +873,9 @@ def main():
941 if global_step >= args.max_train_steps: 873 if global_step >= args.max_train_steps:
942 break 874 break
943 875
944 train_loss /= len(train_dataloader)
945
946 accelerator.wait_for_everyone() 876 accelerator.wait_for_everyone()
947 877
948 text_encoder.eval() 878 text_encoder.eval()
949 val_loss = 0.0
950 879
951 with torch.inference_mode(): 880 with torch.inference_mode():
952 for step, batch in enumerate(val_dataloader): 881 for step, batch in enumerate(val_dataloader):
@@ -976,29 +905,37 @@ def main():
976 905
977 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 906 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
978 907
979 loss = loss.detach().item() 908 acc = (model_pred == latents).float().mean()
980 val_loss += loss 909
910 avg_loss_val.update(loss.detach_(), bsz)
911 avg_acc_val.update(acc.detach_(), bsz)
981 912
982 if accelerator.sync_gradients: 913 if accelerator.sync_gradients:
983 local_progress_bar.update(1) 914 local_progress_bar.update(1)
984 global_progress_bar.update(1) 915 global_progress_bar.update(1)
985 916
986 logs = {"val/loss": loss} 917 logs = {
918 "val/loss": avg_loss_val.avg.item(),
919 "val/acc": avg_acc_val.avg.item(),
920 "val/cur_loss": loss.item(),
921 "val/cur_acc": acc.item(),
922 }
987 local_progress_bar.set_postfix(**logs) 923 local_progress_bar.set_postfix(**logs)
988 924
989 val_loss /= len(val_dataloader) 925 accelerator.log({
990 926 "val/loss": avg_loss_val.avg.item(),
991 accelerator.log({"val/loss": val_loss}, step=global_step) 927 "val/acc": avg_acc_val.avg.item(),
928 }, step=global_step)
992 929
993 local_progress_bar.clear() 930 local_progress_bar.clear()
994 global_progress_bar.clear() 931 global_progress_bar.clear()
995 932
996 if accelerator.is_main_process: 933 if accelerator.is_main_process:
997 if min_val_loss > val_loss: 934 if avg_acc_val.avg.item() > max_acc_val:
998 accelerator.print( 935 accelerator.print(
999 f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") 936 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
1000 checkpointer.checkpoint(global_step + global_step_offset, "milestone") 937 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
1001 min_val_loss = val_loss 938 max_acc_val = avg_acc_val.avg.item()
1002 939
1003 if (epoch + 1) % args.checkpoint_frequency == 0: 940 if (epoch + 1) % args.checkpoint_frequency == 0:
1004 checkpointer.checkpoint(global_step + global_step_offset, "training") 941 checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -1007,9 +944,7 @@ def main():
1007 }) 944 })
1008 945
1009 if (epoch + 1) % args.sample_frequency == 0: 946 if (epoch + 1) % args.sample_frequency == 0:
1010 checkpointer.save_samples( 947 checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
1011 global_step + global_step_offset,
1012 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
1013 948
1014 # Create the pipeline using using the trained modules and save it. 949 # Create the pipeline using using the trained modules and save it.
1015 if accelerator.is_main_process: 950 if accelerator.is_main_process:
diff --git a/training/optimization.py b/training/optimization.py
index 0e603fa..c501ed9 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ from diffusers.utils import logging
6logger = logging.get_logger(__name__) 6logger = logging.get_logger(__name__)
7 7
8 8
9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.4, last_epoch=-1): 9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.001, mid_point=0.4, last_epoch=-1):
10 """ 10 """
11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
diff --git a/training/util.py b/training/util.py
new file mode 100644
index 0000000..e8d22ae
--- /dev/null
+++ b/training/util.py
@@ -0,0 +1,131 @@
1from pathlib import Path
2import json
3
4import torch
5from PIL import Image
6
7
8def freeze_params(params):
9 for param in params:
10 param.requires_grad = False
11
12
13def save_args(basepath: Path, args, extra={}):
14 info = {"args": vars(args)}
15 info["args"].update(extra)
16 with open(basepath.joinpath("args.json"), "w") as f:
17 json.dump(info, f, indent=4)
18
19
20def make_grid(images, rows, cols):
21 w, h = images[0].size
22 grid = Image.new('RGB', size=(cols*w, rows*h))
23 for i, image in enumerate(images):
24 grid.paste(image, box=(i % cols*w, i//cols*h))
25 return grid
26
27
28class AverageMeter:
29 def __init__(self, name=None):
30 self.name = name
31 self.reset()
32
33 def reset(self):
34 self.sum = self.count = self.avg = 0
35
36 def update(self, val, n=1):
37 self.sum += val * n
38 self.count += n
39 self.avg = self.sum / self.count
40
41
42class CheckpointerBase:
43 def __init__(
44 self,
45 datamodule,
46 output_dir: Path,
47 instance_identifier,
48 placeholder_token,
49 placeholder_token_id,
50 sample_image_size,
51 sample_batches,
52 sample_batch_size,
53 seed
54 ):
55 self.datamodule = datamodule
56 self.output_dir = output_dir
57 self.instance_identifier = instance_identifier
58 self.placeholder_token = placeholder_token
59 self.placeholder_token_id = placeholder_token_id
60 self.sample_image_size = sample_image_size
61 self.seed = seed or torch.random.seed()
62 self.sample_batches = sample_batches
63 self.sample_batch_size = sample_batch_size
64
65 @torch.no_grad()
66 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
67 samples_path = Path(self.output_dir).joinpath("samples")
68
69 train_data = self.datamodule.train_dataloader()
70 val_data = self.datamodule.val_dataloader()
71
72 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
73 stable_latents = torch.randn(
74 (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
75 device=pipeline.device,
76 generator=generator,
77 )
78
79 with torch.autocast("cuda"), torch.inference_mode():
80 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
81 all_samples = []
82 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
83 file_path.parent.mkdir(parents=True, exist_ok=True)
84
85 data_enum = enumerate(data)
86
87 batches = [
88 batch
89 for j, batch in data_enum
90 if j * data.batch_size < self.sample_batch_size * self.sample_batches
91 ]
92 prompts = [
93 prompt.format(identifier=self.instance_identifier)
94 for batch in batches
95 for prompt in batch["prompts"]
96 ]
97 nprompts = [
98 prompt
99 for batch in batches
100 for prompt in batch["nprompts"]
101 ]
102
103 for i in range(self.sample_batches):
104 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
105 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
106
107 samples = pipeline(
108 prompt=prompt,
109 negative_prompt=nprompt,
110 height=self.sample_image_size,
111 width=self.sample_image_size,
112 image=latents[:len(prompt)] if latents is not None else None,
113 generator=generator if latents is not None else None,
114 guidance_scale=guidance_scale,
115 eta=eta,
116 num_inference_steps=num_inference_steps,
117 output_type='pil'
118 ).images
119
120 all_samples += samples
121
122 del samples
123
124 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
125 image_grid.save(file_path, quality=85)
126
127 del all_samples
128 del image_grid
129
130 del generator
131 del stable_latents