path: root/train_ti.py
author    Volpeon <git@volpeon.ink>  2023-01-15 22:26:43 +0100
committer Volpeon <git@volpeon.ink>  2023-01-15 22:26:43 +0100
commit    3f922880475c2c0a5679987d4a9a43606e838566 (patch)
tree      757746927e34aa7fddff1e44c837b489233029d7 /train_ti.py
parent    Restored functional trainer (diff)
download  textual-inversion-diff-3f922880475c2c0a5679987d4a9a43606e838566.tar.gz
          textual-inversion-diff-3f922880475c2c0a5679987d4a9a43606e838566.tar.bz2
          textual-inversion-diff-3f922880475c2c0a5679987d4a9a43606e838566.zip
Added Dreambooth strategy
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  46
1 file changed, 23 insertions, 23 deletions
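The diff below relocates the `trainer = partial(...)` construction and the optimizer setup so that both are built right after the sampling setup and before the `find_lr`/`lr_scheduler` branch; the bound arguments themselves are unchanged. As a minimal, self-contained sketch of the `functools.partial` pattern in play (the `train()` signature here is a hypothetical stand-in for the real one):

    from functools import partial

    # Hypothetical stand-in for the real train() entry point; only the
    # keyword-only signature matters for illustrating the pattern.
    def train(*, dtype, seed, optimizer, lr_scheduler):
        print(dtype, seed, optimizer, lr_scheduler)

    # Bind everything that is already known; partial() returns a callable
    # that still expects the missing keyword arguments.
    trainer = partial(train, dtype="float32", seed=42)

    # Later, once the optimizer and scheduler exist, complete the call.
    trainer(optimizer="AdamW", lr_scheduler="one_cycle")

Building the partial first keeps the one-off arguments (optimizer, scheduler) out of the shared setup, so the final `trainer(...)` call only has to supply what actually varies.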
diff --git a/train_ti.py b/train_ti.py
index 77dec12..2497519 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -557,15 +557,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -624,6 +615,29 @@ def main():
         args.sample_steps
     )
 
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    optimizer = optimizer_class(
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -642,20 +656,6 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    trainer = partial(
-        train,
-        accelerator=accelerator,
-        unet=unet,
-        text_encoder=text_encoder,
-        vae=vae,
-        noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        dtype=weight_dtype,
-        seed=args.seed,
-        callbacks_fn=textual_inversion_strategy
-    )
-
     trainer(
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
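Note that the relocated optimizer only receives `text_encoder.text_model.embeddings.temp_token_embedding.parameters()`: in textual inversion, only the newly added token embeddings are optimized while the rest of the model stays frozen. A minimal sketch of that idea, assuming a plain `nn.Embedding` in place of the real temp-token embedding:

    import torch
    import torch.nn as nn

    # Hypothetical minimal setup: a frozen base embedding plus a small
    # trainable embedding holding only the newly added placeholder tokens.
    base_embedding = nn.Embedding(49408, 768)
    base_embedding.requires_grad_(False)          # frozen pretrained weights
    temp_token_embedding = nn.Embedding(4, 768)   # only these rows get gradients

    # Mirror the diff: hand AdamW just the trainable embedding's parameters.
    optimizer = torch.optim.AdamW(
        temp_token_embedding.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-8,
        amsgrad=False,
    )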