diff options
| -rw-r--r-- | train_ti.py | 49 | ||||
| -rw-r--r-- | trainer_old/base.py | 14 | ||||
| -rw-r--r-- | training/functional.py | 75 |
3 files changed, 101 insertions, 37 deletions
diff --git a/train_ti.py b/train_ti.py index 78c1b5c..97e4e72 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -17,7 +17,7 @@ from slugify import slugify | |||
| 17 | from util import load_config, load_embeddings_from_dir | 17 | from util import load_config, load_embeddings_from_dir |
| 18 | from data.csv import VlpnDataModule, VlpnDataItem | 18 | from data.csv import VlpnDataModule, VlpnDataItem |
| 19 | from trainer_old.base import Checkpointer | 19 | from trainer_old.base import Checkpointer |
| 20 | from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models | 20 | from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models |
| 21 | from training.optimization import get_scheduler | 21 | from training.optimization import get_scheduler |
| 22 | from training.lr import LRFinder | 22 | from training.lr import LRFinder |
| 23 | from training.util import EMAModel, save_args | 23 | from training.util import EMAModel, save_args |
| @@ -703,17 +703,27 @@ def main(): | |||
| 703 | warmup_epochs=args.lr_warmup_epochs, | 703 | warmup_epochs=args.lr_warmup_epochs, |
| 704 | ) | 704 | ) |
| 705 | 705 | ||
| 706 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | ||
| 707 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 708 | ) | ||
| 709 | |||
| 710 | vae.to(accelerator.device, dtype=weight_dtype) | ||
| 711 | |||
| 712 | if args.use_ema: | 706 | if args.use_ema: |
| 713 | ema_embeddings.to(accelerator.device) | 707 | ema_embeddings.to(accelerator.device) |
| 714 | 708 | ||
| 715 | if args.gradient_checkpointing: | 709 | trainer = partial( |
| 716 | unet.train() | 710 | train, |
| 711 | accelerator=accelerator, | ||
| 712 | vae=vae, | ||
| 713 | unet=unet, | ||
| 714 | text_encoder=text_encoder, | ||
| 715 | noise_scheduler=noise_scheduler, | ||
| 716 | train_dataloader=train_dataloader, | ||
| 717 | val_dataloader=val_dataloader, | ||
| 718 | dtype=weight_dtype, | ||
| 719 | seed=args.seed, | ||
| 720 | ) | ||
| 721 | |||
| 722 | def on_prepare(): | ||
| 723 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) | ||
| 724 | |||
| 725 | if args.gradient_checkpointing: | ||
| 726 | unet.train() | ||
| 717 | 727 | ||
| 718 | @contextmanager | 728 | @contextmanager |
| 719 | def on_train(epoch: int): | 729 | def on_train(epoch: int): |
| @@ -752,16 +762,6 @@ def main(): | |||
| 752 | return {"ema_decay": ema_embeddings.decay} | 762 | return {"ema_decay": ema_embeddings.decay} |
| 753 | return {} | 763 | return {} |
| 754 | 764 | ||
| 755 | loss_step_ = partial( | ||
| 756 | loss_step, | ||
| 757 | vae, | ||
| 758 | noise_scheduler, | ||
| 759 | unet, | ||
| 760 | text_encoder, | ||
| 761 | args.prior_loss_weight, | ||
| 762 | args.seed, | ||
| 763 | ) | ||
| 764 | |||
| 765 | checkpointer = TextualInversionCheckpointer( | 765 | checkpointer = TextualInversionCheckpointer( |
| 766 | dtype=weight_dtype, | 766 | dtype=weight_dtype, |
| 767 | train_dataloader=train_dataloader, | 767 | train_dataloader=train_dataloader, |
| @@ -803,18 +803,15 @@ def main(): | |||
| 803 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) | 803 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) |
| 804 | plt.close() | 804 | plt.close() |
| 805 | else: | 805 | else: |
| 806 | train_loop( | 806 | trainer( |
| 807 | accelerator=accelerator, | ||
| 808 | optimizer=optimizer, | 807 | optimizer=optimizer, |
| 809 | lr_scheduler=lr_scheduler, | 808 | lr_scheduler=lr_scheduler, |
| 810 | model=text_encoder, | 809 | num_train_epochs=args.num_train_epochs, |
| 811 | train_dataloader=train_dataloader, | ||
| 812 | val_dataloader=val_dataloader, | ||
| 813 | loss_step=loss_step_, | ||
| 814 | sample_frequency=args.sample_frequency, | 810 | sample_frequency=args.sample_frequency, |
| 815 | checkpoint_frequency=args.checkpoint_frequency, | 811 | checkpoint_frequency=args.checkpoint_frequency, |
| 816 | global_step_offset=global_step_offset, | 812 | global_step_offset=global_step_offset, |
| 817 | num_epochs=args.num_train_epochs, | 813 | prior_loss_weight=args.prior_loss_weight, |
| 814 | on_prepare=on_prepare, | ||
| 818 | on_log=on_log, | 815 | on_log=on_log, |
| 819 | on_train=on_train, | 816 | on_train=on_train, |
| 820 | on_after_optimize=on_after_optimize, | 817 | on_after_optimize=on_after_optimize, |
diff --git a/trainer_old/base.py b/trainer_old/base.py index 1f85e71..5903d96 100644 --- a/trainer_old/base.py +++ b/trainer_old/base.py | |||
| @@ -174,19 +174,13 @@ class TrainingStrategy(): | |||
| 174 | 174 | ||
| 175 | @contextmanager | 175 | @contextmanager |
| 176 | def on_train(self, epoch: int): | 176 | def on_train(self, epoch: int): |
| 177 | try: | 177 | self.tokenizer.train() |
| 178 | self.tokenizer.train() | 178 | yield |
| 179 | yield | ||
| 180 | finally: | ||
| 181 | pass | ||
| 182 | 179 | ||
| 183 | @contextmanager | 180 | @contextmanager |
| 184 | def on_eval(self): | 181 | def on_eval(self): |
| 185 | try: | 182 | self.tokenizer.eval() |
| 186 | self.tokenizer.eval() | 183 | yield |
| 187 | yield | ||
| 188 | finally: | ||
| 189 | pass | ||
| 190 | 184 | ||
| 191 | def on_before_optimize(self, epoch: int): | 185 | def on_before_optimize(self, epoch: int): |
| 192 | ... | 186 | ... |
diff --git a/training/functional.py b/training/functional.py index c5b514a..1f2ca6d 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -1,6 +1,7 @@ | |||
| 1 | import math | 1 | import math |
| 2 | from contextlib import _GeneratorContextManager, nullcontext | 2 | from contextlib import _GeneratorContextManager, nullcontext |
| 3 | from typing import Callable, Any, Tuple, Union | 3 | from typing import Callable, Any, Tuple, Union, Optional |
| 4 | from functools import partial | ||
| 4 | 5 | ||
| 5 | import torch | 6 | import torch |
| 6 | import torch.nn.functional as F | 7 | import torch.nn.functional as F |
| @@ -376,3 +377,75 @@ def train_loop( | |||
| 376 | print("Interrupted") | 377 | print("Interrupted") |
| 377 | on_checkpoint(global_step + global_step_offset, "end") | 378 | on_checkpoint(global_step + global_step_offset, "end") |
| 378 | accelerator.end_training() | 379 | accelerator.end_training() |
| 380 | |||
| 381 | |||
| 382 | def train( | ||
| 383 | accelerator: Accelerator, | ||
| 384 | unet: UNet2DConditionModel, | ||
| 385 | text_encoder: CLIPTextModel, | ||
| 386 | vae: AutoencoderKL, | ||
| 387 | noise_scheduler: DDPMScheduler, | ||
| 388 | train_dataloader: DataLoader, | ||
| 389 | val_dataloader: DataLoader, | ||
| 390 | dtype: torch.dtype, | ||
| 391 | seed: int, | ||
| 392 | optimizer: torch.optim.Optimizer, | ||
| 393 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
| 394 | num_train_epochs: int = 100, | ||
| 395 | sample_frequency: int = 20, | ||
| 396 | checkpoint_frequency: int = 50, | ||
| 397 | global_step_offset: int = 0, | ||
| 398 | prior_loss_weight: float = 0, | ||
| 399 | on_prepare: Callable[[], dict[str, Any]] = const({}), | ||
| 400 | on_log: Callable[[], dict[str, Any]] = const({}), | ||
| 401 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), | ||
| 402 | on_before_optimize: Callable[[int], None] = const(), | ||
| 403 | on_after_optimize: Callable[[float], None] = const(), | ||
| 404 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()), | ||
| 405 | on_sample: Callable[[int], None] = const(), | ||
| 406 | on_checkpoint: Callable[[int, str], None] = const(), | ||
| 407 | ): | ||
| 408 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | ||
| 409 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 410 | ) | ||
| 411 | |||
| 412 | vae.to(accelerator.device, dtype=dtype) | ||
| 413 | |||
| 414 | for model in (unet, text_encoder, vae): | ||
| 415 | model.requires_grad_(False) | ||
| 416 | model.eval() | ||
| 417 | |||
| 418 | on_prepare() | ||
| 419 | |||
| 420 | loss_step_ = partial( | ||
| 421 | loss_step, | ||
| 422 | vae, | ||
| 423 | noise_scheduler, | ||
| 424 | unet, | ||
| 425 | text_encoder, | ||
| 426 | prior_loss_weight, | ||
| 427 | seed, | ||
| 428 | ) | ||
| 429 | |||
| 430 | train_loop( | ||
| 431 | accelerator=accelerator, | ||
| 432 | optimizer=optimizer, | ||
| 433 | lr_scheduler=lr_scheduler, | ||
| 434 | model=text_encoder, | ||
| 435 | train_dataloader=train_dataloader, | ||
| 436 | val_dataloader=val_dataloader, | ||
| 437 | loss_step=loss_step_, | ||
| 438 | sample_frequency=sample_frequency, | ||
| 439 | checkpoint_frequency=checkpoint_frequency, | ||
| 440 | global_step_offset=global_step_offset, | ||
| 441 | num_epochs=num_train_epochs, | ||
| 442 | on_log=on_log, | ||
| 443 | on_train=on_train, | ||
| 444 | on_before_optimize=on_before_optimize, | ||
| 445 | on_after_optimize=on_after_optimize, | ||
| 446 | on_eval=on_eval, | ||
| 447 | on_sample=on_sample, | ||
| 448 | on_checkpoint=on_checkpoint, | ||
| 449 | ) | ||
| 450 | |||
| 451 | accelerator.free_memory() | ||
