Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 74
1 file changed, 15 insertions(+), 59 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 8631892..deed84c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -19,10 +19,11 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from trainer.base import Checkpointer
+from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import CheckpointerBase, EMAModel, save_args
+from training.util import EMAModel, save_args
 from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
@@ -480,38 +481,20 @@ def parse_args():
     return args
 
 
-class Checkpointer(CheckpointerBase):
+class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
-        weight_dtype: torch.dtype,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
         ema_embeddings: EMAModel,
-        scheduler,
-        placeholder_tokens,
-        placeholder_token_ids,
         *args,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
         self.ema_embeddings = ema_embeddings
-        self.scheduler = scheduler
-        self.placeholder_tokens = placeholder_tokens
-        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
-        print("Saving checkpoint for step %d..." % step)
+        print(f"Saving checkpoint for step {step}...")
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
@@ -519,7 +502,8 @@ class Checkpointer(CheckpointerBase):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
             for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
@@ -528,42 +512,14 @@ class Checkpointer(CheckpointerBase):
                     checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                 )
 
-        del text_encoder
-
-    @torch.no_grad()
+    @torch.inference_mode()
     def save_samples(self, step):
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
         ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
-            orig_unet_dtype = unet.dtype
-            orig_text_encoder_dtype = text_encoder.dtype
-
-            unet.to(dtype=self.weight_dtype)
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=self.unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step)
-
-            unet.to(dtype=orig_unet_dtype)
-            text_encoder.to(dtype=orig_text_encoder_dtype)
-
-            del text_encoder
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            super().save_samples(step)
 
 
 def main():
@@ -806,8 +762,8 @@ def main():
         args.seed,
     )
 
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
+    checkpointer = TextualInversionCheckpointer(
+        dtype=weight_dtype,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         accelerator=accelerator,
@@ -816,7 +772,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         ema_embeddings=ema_embeddings,
-        scheduler=sample_scheduler,
+        sample_scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         output_dir=output_dir,
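
For reference, both `checkpoint` and `save_samples` guard their work with `ema_embeddings.apply_temporary(...)`: a context manager that temporarily swaps the EMA-averaged weights into the embedding parameters and restores the live training weights on exit. The diff does not show the implementation in `training.util`, so the sketch below is a minimal, hypothetical illustration of that pattern only; the `SimpleEMA` name, the `decay` default, and the `step` method are assumptions, not this repo's `EMAModel` API.

from contextlib import contextmanager

import torch


class SimpleEMA:
    """Hypothetical EMA helper illustrating the apply_temporary pattern."""

    def __init__(self, parameters, decay=0.999):
        self.decay = decay
        # Detached, exponentially averaged copies of the tracked parameters.
        self.shadow = [p.detach().clone() for p in parameters]

    @torch.no_grad()
    def step(self, parameters):
        # Standard EMA update: shadow <- decay * shadow + (1 - decay) * param.
        for s, p in zip(self.shadow, parameters):
            s.mul_(self.decay).add_(p, alpha=1 - self.decay)

    @contextmanager
    def apply_temporary(self, parameters):
        parameters = list(parameters)
        # Stash the live weights and copy the EMA weights in...
        backup = [p.detach().clone() for p in parameters]
        for s, p in zip(self.shadow, parameters):
            p.data.copy_(s)
        try:
            yield
        finally:
            # ...then restore the live weights, even if the block raised.
            for b, p in zip(backup, parameters):
                p.data.copy_(b)

Used as `with ema.apply_temporary(embedding.parameters()): ...`, checkpoints and sample images reflect the smoothed weights while the optimizer keeps training against the live ones, which is exactly the shape of the refactored `checkpoint` and `save_samples` above.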