-rw-r--r--  train_ti.py             49
-rw-r--r--  trainer_old/base.py     14
-rw-r--r--  training/functional.py  75
3 files changed, 101 insertions, 37 deletions
diff --git a/train_ti.py b/train_ti.py
index 78c1b5c..97e4e72 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
 from trainer_old.base import Checkpointer
-from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
@@ -703,17 +703,27 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
-    if args.gradient_checkpointing:
-        unet.train()
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        text_encoder=text_encoder,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+    )
+
+    def on_prepare():
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+    if args.gradient_checkpointing:
+        unet.train()
 
     @contextmanager
     def on_train(epoch: int):
@@ -752,16 +762,6 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
     checkpointer = TextualInversionCheckpointer(
         dtype=weight_dtype,
         train_dataloader=train_dataloader,
@@ -803,18 +803,15 @@ def main():
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        train_loop(
-            accelerator=accelerator,
+        trainer(
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            model=text_encoder,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
+            num_train_epochs=args.num_train_epochs,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            num_epochs=args.num_train_epochs,
+            prior_loss_weight=args.prior_loss_weight,
+            on_prepare=on_prepare,
             on_log=on_log,
             on_train=on_train,
             on_after_optimize=on_after_optimize,
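
Note: the two train_ti.py hunks above split one call into a functools.partial that binds the model and data arguments, plus a later trainer(...) invocation that adds the run-specific ones; together they are equivalent to a single call to training.functional.train with the merged keyword arguments. A minimal sketch of that equivalence, using placeholder values instead of real models:

from functools import partial

def train(**kwargs):
    # Stand-in for training.functional.train, used only to show how the
    # keyword arguments from the two call sites merge into one call.
    return kwargs

trainer = partial(train, vae="vae", dtype="float16", seed=42)

# Supplying the run-specific arguments later is equivalent to one combined call:
assert trainer(optimizer="opt", num_train_epochs=100) == train(
    vae="vae", dtype="float16", seed=42, optimizer="opt", num_train_epochs=100
)
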
diff --git a/trainer_old/base.py b/trainer_old/base.py
index 1f85e71..5903d96 100644
--- a/trainer_old/base.py
+++ b/trainer_old/base.py
@@ -174,19 +174,13 @@ class TrainingStrategy():
 
     @contextmanager
     def on_train(self, epoch: int):
-        try:
-            self.tokenizer.train()
-            yield
-        finally:
-            pass
+        self.tokenizer.train()
+        yield
 
     @contextmanager
     def on_eval(self):
-        try:
-            self.tokenizer.eval()
-            yield
-        finally:
-            pass
+        self.tokenizer.eval()
+        yield
 
     def on_before_optimize(self, epoch: int):
         ...
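
Note: dropping the try/finally around the yield is behavior-preserving here, because the finally block only executed pass; cleanup after the yield would matter only if the strategy had state to restore. A minimal sketch of the simplified pattern and of how such a hook is presumably consumed (the with-statement call site is inferred from the _GeneratorContextManager hook types in training/functional.py, and the tokenizer stand-in is hypothetical):

from contextlib import contextmanager

class _Tokenizer:
    # Hypothetical stand-in exposing the train()/eval() toggle used above.
    def train(self): self.mode = "train"
    def eval(self): self.mode = "eval"

class TrainingStrategy:
    def __init__(self):
        self.tokenizer = _Tokenizer()

    @contextmanager
    def on_train(self, epoch: int):
        self.tokenizer.train()  # same effect as the old try/finally version
        yield

strategy = TrainingStrategy()
with strategy.on_train(epoch=0):
    pass  # one training epoch would run here
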
diff --git a/training/functional.py b/training/functional.py
index c5b514a..1f2ca6d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,6 +1,7 @@
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Optional
+from functools import partial
 
 import torch
 import torch.nn.functional as F
@@ -376,3 +377,75 @@ def train_loop(
             print("Interrupted")
             on_checkpoint(global_step + global_step_offset, "end")
             accelerator.end_training()
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    prior_loss_weight: float = 0,
+    on_prepare: Callable[[], dict[str, Any]] = const({}),
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        prior_loss_weight,
+        seed,
+    )
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=text_encoder,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval,
+        on_sample=on_sample,
+        on_checkpoint=on_checkpoint,
+    )
+
+    accelerator.free_memory()
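
Note: the hook defaults in the new train() signature (const({}), const(nullcontext()), const()) rely on a const helper that already lives in training/functional.py and is not part of this diff. Presumably it builds a callback that ignores its arguments and returns a fixed value, along the lines of this sketch (only the name comes from the diff; the body is assumed):

def const(result=None):
    # Return a callback that swallows any arguments and always returns `result`,
    # so every optional hook (on_log, on_train, on_checkpoint, ...) can default
    # to a harmless no-op.
    def fn(*args, **kwargs):
        return result
    return fn

# Example: the default on_log callback just reports no extra metrics.
on_log = const({})
assert on_log() == {}
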