diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 21:53:07 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 21:53:07 +0100 |
commit | 83808fe00ac891ad2f625388d144c318b2cb5bfe (patch) | |
tree | b7ca19d27f90be6f02b14f4a39c62fc7250041a2 /train_ti.py | |
parent | TI: Prepare UNet with Accelerate as well (diff) | |
download | textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.gz textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.bz2 textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.zip |
WIP: Modularization ("free(): invalid pointer" my ass)
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 74 |
1 files changed, 15 insertions, 59 deletions
diff --git a/train_ti.py b/train_ti.py index 8631892..deed84c 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -19,10 +19,11 @@ from slugify import slugify | |||
19 | from util import load_config, load_embeddings_from_dir | 19 | from util import load_config, load_embeddings_from_dir |
20 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 20 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
21 | from data.csv import VlpnDataModule, VlpnDataItem | 21 | from data.csv import VlpnDataModule, VlpnDataItem |
22 | from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models | 22 | from trainer.base import Checkpointer |
23 | from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models | ||
23 | from training.optimization import get_scheduler | 24 | from training.optimization import get_scheduler |
24 | from training.lr import LRFinder | 25 | from training.lr import LRFinder |
25 | from training.util import CheckpointerBase, EMAModel, save_args | 26 | from training.util import EMAModel, save_args |
26 | from models.clip.tokenizer import MultiCLIPTokenizer | 27 | from models.clip.tokenizer import MultiCLIPTokenizer |
27 | 28 | ||
28 | logger = get_logger(__name__) | 29 | logger = get_logger(__name__) |
@@ -480,38 +481,20 @@ def parse_args(): | |||
480 | return args | 481 | return args |
481 | 482 | ||
482 | 483 | ||
483 | class Checkpointer(CheckpointerBase): | 484 | class TextualInversionCheckpointer(Checkpointer): |
484 | def __init__( | 485 | def __init__( |
485 | self, | 486 | self, |
486 | weight_dtype: torch.dtype, | ||
487 | accelerator: Accelerator, | ||
488 | vae: AutoencoderKL, | ||
489 | unet: UNet2DConditionModel, | ||
490 | tokenizer: MultiCLIPTokenizer, | ||
491 | text_encoder: CLIPTextModel, | ||
492 | ema_embeddings: EMAModel, | 487 | ema_embeddings: EMAModel, |
493 | scheduler, | ||
494 | placeholder_tokens, | ||
495 | placeholder_token_ids, | ||
496 | *args, | 488 | *args, |
497 | **kwargs | 489 | **kwargs, |
498 | ): | 490 | ): |
499 | super().__init__(*args, **kwargs) | 491 | super().__init__(*args, **kwargs) |
500 | 492 | ||
501 | self.weight_dtype = weight_dtype | ||
502 | self.accelerator = accelerator | ||
503 | self.vae = vae | ||
504 | self.unet = unet | ||
505 | self.tokenizer = tokenizer | ||
506 | self.text_encoder = text_encoder | ||
507 | self.ema_embeddings = ema_embeddings | 493 | self.ema_embeddings = ema_embeddings |
508 | self.scheduler = scheduler | ||
509 | self.placeholder_tokens = placeholder_tokens | ||
510 | self.placeholder_token_ids = placeholder_token_ids | ||
511 | 494 | ||
512 | @torch.no_grad() | 495 | @torch.no_grad() |
513 | def checkpoint(self, step, postfix): | 496 | def checkpoint(self, step, postfix): |
514 | print("Saving checkpoint for step %d..." % step) | 497 | print(f"Saving checkpoint for step {step}...") |
515 | 498 | ||
516 | checkpoints_path = self.output_dir.joinpath("checkpoints") | 499 | checkpoints_path = self.output_dir.joinpath("checkpoints") |
517 | checkpoints_path.mkdir(parents=True, exist_ok=True) | 500 | checkpoints_path.mkdir(parents=True, exist_ok=True) |
@@ -519,7 +502,8 @@ class Checkpointer(CheckpointerBase): | |||
519 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 502 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
520 | 503 | ||
521 | ema_context = self.ema_embeddings.apply_temporary( | 504 | ema_context = self.ema_embeddings.apply_temporary( |
522 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() | 505 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() |
506 | ) if self.ema_embeddings is not None else nullcontext() | ||
523 | 507 | ||
524 | with ema_context: | 508 | with ema_context: |
525 | for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids): | 509 | for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids): |
@@ -528,42 +512,14 @@ class Checkpointer(CheckpointerBase): | |||
528 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 512 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
529 | ) | 513 | ) |
530 | 514 | ||
531 | del text_encoder | 515 | @torch.inference_mode() |
532 | |||
533 | @torch.no_grad() | ||
534 | def save_samples(self, step): | 516 | def save_samples(self, step): |
535 | unet = self.accelerator.unwrap_model(self.unet) | ||
536 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | ||
537 | |||
538 | ema_context = self.ema_embeddings.apply_temporary( | 517 | ema_context = self.ema_embeddings.apply_temporary( |
539 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() | 518 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() |
519 | ) if self.ema_embeddings is not None else nullcontext() | ||
540 | 520 | ||
541 | with ema_context: | 521 | with ema_context: |
542 | orig_unet_dtype = unet.dtype | 522 | super().save_samples(step) |
543 | orig_text_encoder_dtype = text_encoder.dtype | ||
544 | |||
545 | unet.to(dtype=self.weight_dtype) | ||
546 | text_encoder.to(dtype=self.weight_dtype) | ||
547 | |||
548 | pipeline = VlpnStableDiffusion( | ||
549 | text_encoder=text_encoder, | ||
550 | vae=self.vae, | ||
551 | unet=self.unet, | ||
552 | tokenizer=self.tokenizer, | ||
553 | scheduler=self.scheduler, | ||
554 | ).to(self.accelerator.device) | ||
555 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
556 | |||
557 | super().save_samples(pipeline, step) | ||
558 | |||
559 | unet.to(dtype=orig_unet_dtype) | ||
560 | text_encoder.to(dtype=orig_text_encoder_dtype) | ||
561 | |||
562 | del text_encoder | ||
563 | del pipeline | ||
564 | |||
565 | if torch.cuda.is_available(): | ||
566 | torch.cuda.empty_cache() | ||
567 | 523 | ||
568 | 524 | ||
569 | def main(): | 525 | def main(): |
@@ -806,8 +762,8 @@ def main(): | |||
806 | args.seed, | 762 | args.seed, |
807 | ) | 763 | ) |
808 | 764 | ||
809 | checkpointer = Checkpointer( | 765 | checkpointer = TextualInversionCheckpointer( |
810 | weight_dtype=weight_dtype, | 766 | dtype=weight_dtype, |
811 | train_dataloader=train_dataloader, | 767 | train_dataloader=train_dataloader, |
812 | val_dataloader=val_dataloader, | 768 | val_dataloader=val_dataloader, |
813 | accelerator=accelerator, | 769 | accelerator=accelerator, |
@@ -816,7 +772,7 @@ def main(): | |||
816 | tokenizer=tokenizer, | 772 | tokenizer=tokenizer, |
817 | text_encoder=text_encoder, | 773 | text_encoder=text_encoder, |
818 | ema_embeddings=ema_embeddings, | 774 | ema_embeddings=ema_embeddings, |
819 | scheduler=sample_scheduler, | 775 | sample_scheduler=sample_scheduler, |
820 | placeholder_tokens=args.placeholder_tokens, | 776 | placeholder_tokens=args.placeholder_tokens, |
821 | placeholder_token_ids=placeholder_token_ids, | 777 | placeholder_token_ids=placeholder_token_ids, |
822 | output_dir=output_dir, | 778 | output_dir=output_dir, |