author	Volpeon <git@volpeon.ink>	2023-01-16 10:03:05 +0100
committer	Volpeon <git@volpeon.ink>	2023-01-16 10:03:05 +0100
commit	d5696615a84a768307e82d13e50b4aef64f69dbd (patch)
tree	47cfaa5b8922edbfd567739ecd770977e339f8d7
parent	Implemented extended Dreambooth training (diff)
download	textual-inversion-diff-d5696615a84a768307e82d13e50b4aef64f69dbd.tar.gz
	textual-inversion-diff-d5696615a84a768307e82d13e50b4aef64f69dbd.tar.bz2
	textual-inversion-diff-d5696615a84a768307e82d13e50b4aef64f69dbd.zip
Extended Dreambooth: Train TI tokens separately
-rw-r--r--	data/csv.py	7
-rw-r--r--	train_dreambooth.py	147
-rw-r--r--	training/functional.py	1
3 files changed, 84 insertions(+), 71 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 2b1e202..002fdd2 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -15,6 +15,9 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.util import unify_input_ids
 
 
+cache = {}
+
+
 interpolations = {
     "linear": transforms.InterpolationMode.NEAREST,
     "bilinear": transforms.InterpolationMode.BILINEAR,
@@ -24,9 +27,13 @@ interpolations = {
 
 
 def get_image(path):
+    if path in cache:
+        return cache[path]
+
     image = Image.open(path)
     if not image.mode == "RGB":
         image = image.convert("RGB")
+    cache[path] = image
     return image
 
 
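The data/csv.py change above adds a simple module-level memoization of decoded images, so repeated epochs reuse the in-memory RGB image instead of re-opening the file every time. A rough standalone sketch of the same pattern follows (assuming PIL; this is illustrative, not the repo's exact module, and the trade-off is that every distinct image stays in RAM for the life of the process):

# Minimal sketch of the caching pattern introduced in data/csv.py.
from PIL import Image

cache = {}

def get_image(path):
    # Reuse an already-decoded image for this path.
    if path in cache:
        return cache[path]

    image = Image.open(path)
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Keep the decoded RGB image in memory; trades RAM for repeated disk I/O.
    cache[path] = image
    return image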
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 944256c..05777d0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -510,19 +510,6 @@ def main():
     added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-        tokenizer=tokenizer,
-        embeddings=embeddings,
-        placeholder_tokens=args.placeholder_tokens,
-        initializer_tokens=args.initializer_tokens,
-        num_vectors=args.num_vectors
-    )
-
-    if len(placeholder_token_ids) != 0:
-        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
-        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
-        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -554,79 +541,98 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
-        callbacks_fn=textual_inversion_strategy
     )
 
     # Initial TI
 
     print("Phase 1: Textual Inversion")
 
+    ti_lr = 1e-1
+    ti_train_epochs = 10
+
     cur_dir = output_dir.joinpath("1-ti")
     cur_dir.mkdir(parents=True, exist_ok=True)
 
-    datamodule = VlpnDataModule(
-        data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
-        class_subdir=args.class_image_dir,
-        num_class_images=args.num_class_images,
-        size=args.resolution,
-        shuffle=not args.no_tag_shuffle,
-        template_key=args.train_data_template,
-        valid_set_size=args.valid_set_size,
-        valid_set_repeat=args.valid_set_repeat,
-        seed=args.seed,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-        dtype=weight_dtype
-    )
-    datamodule.setup()
+    for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors):
+        print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})")
 
-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=2e-1,
-        weight_decay=0.0,
-    )
+        cur_subdir = cur_dir.joinpath(placeholder_token)
+        cur_subdir.mkdir(parents=True, exist_ok=True)
 
-    lr_scheduler = get_scheduler(
-        "linear",
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        train_epochs=30,
-        warmup_epochs=10,
-    )
+        placeholder_token_ids, _ = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=[placeholder_token],
+            initializer_tokens=[initializer_token],
+            num_vectors=num_vectors
+        )
 
-    trainer(
-        project="textual_inversion",
-        train_dataloader=datamodule.train_dataloader,
-        val_dataloader=datamodule.val_dataloader,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        num_train_epochs=30,
-        sample_frequency=5,
-        checkpoint_frequency=9999999,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
-        # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
-        output_dir=cur_dir,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        learning_rate=2e-1,
-        gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=True,
-        sample_batch_size=args.sample_batch_size,
-        sample_num_batches=args.sample_batches,
-        sample_num_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
-    )
+        datamodule = VlpnDataModule(
+            data_file=args.train_data_file,
+            batch_size=args.train_batch_size,
+            tokenizer=tokenizer,
+            class_subdir=args.class_image_dir,
+            num_class_images=args.num_class_images,
+            size=args.resolution,
+            shuffle=not args.no_tag_shuffle,
+            template_key=args.train_data_template,
+            valid_set_size=args.valid_set_size,
+            valid_set_repeat=args.valid_set_repeat,
+            seed=args.seed,
+            filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections),
+            dtype=weight_dtype
+        )
+        datamodule.setup()
+
+        optimizer = optimizer_class(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            lr=ti_lr,
+            weight_decay=0.0,
+        )
+
+        lr_scheduler = get_scheduler(
+            "one_cycle",
+            optimizer=optimizer,
+            num_training_steps_per_epoch=len(datamodule.train_dataloader),
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            train_epochs=ti_train_epochs,
+            warmup_epochs=ti_train_epochs // 4,
+        )
+
+        trainer(
+            callbacks_fn=textual_inversion_strategy,
+            project="textual_inversion",
+            train_dataloader=datamodule.train_dataloader,
+            val_dataloader=datamodule.val_dataloader,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            num_train_epochs=ti_train_epochs,
+            sample_frequency=1,
+            checkpoint_frequency=9999999,
+            with_prior_preservation=args.num_class_images != 0,
+            prior_loss_weight=args.prior_loss_weight,
+            # --
+            tokenizer=tokenizer,
+            sample_scheduler=sample_scheduler,
+            output_dir=cur_subdir,
+            placeholder_tokens=[placeholder_token],
+            placeholder_token_ids=placeholder_token_ids,
+            learning_rate=ti_lr,
+            gradient_checkpointing=args.gradient_checkpointing,
+            use_emb_decay=True,
+            sample_batch_size=args.sample_batch_size,
+            sample_num_batches=args.sample_batches,
+            sample_num_steps=args.sample_steps,
+            sample_image_size=args.sample_image_size,
+        )
+
+        embeddings.persist()
 
     # Dreambooth
 
     print("Phase 2: Dreambooth")
 
-    cur_dir = output_dir.joinpath("2db")
+    cur_dir = output_dir.joinpath("2-db")
     cur_dir.mkdir(parents=True, exist_ok=True)
 
     args.seed = (args.seed + 28635) >> 32
@@ -685,6 +691,7 @@ def main():
     )
 
     trainer(
+        callbacks_fn=dreambooth_strategy,
         project="dreambooth",
         train_dataloader=datamodule.train_dataloader,
         val_dataloader=datamodule.val_dataloader,
@@ -692,14 +699,12 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
-        checkpoint_frequency=args.checkpoint_frequency,
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         output_dir=cur_dir,
-        gradient_checkpointing=args.gradient_checkpointing,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
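The train_dreambooth.py changes replace the single up-front add_placeholder_tokens call with a per-token phase-1 loop: each (placeholder, initializer, num_vectors) triple gets its own output subdirectory, datamodule, optimizer, one_cycle schedule and short textual-inversion run, and its vectors are persisted before the next token (and before the Dreambooth phase) starts. Below is a condensed sketch of that control flow; run_ti is a hypothetical stand-in for the VlpnDataModule / optimizer / get_scheduler / trainer(...) block shown in the diff, and add_placeholder_tokens and embeddings are taken as given, as in the script:

# Sketch of the new phase-1 loop, not the script's exact code.
def phase_1_textual_inversion(args, tokenizer, embeddings, output_dir,
                              add_placeholder_tokens, run_ti):
    ti_lr = 1e-1
    ti_train_epochs = 10

    cur_dir = output_dir.joinpath("1-ti")
    cur_dir.mkdir(parents=True, exist_ok=True)

    # One short TI run per placeholder token instead of one joint run for all.
    for placeholder_token, initializer_token, num_vectors in zip(
        args.placeholder_tokens, args.initializer_tokens, args.num_vectors
    ):
        cur_subdir = cur_dir.joinpath(placeholder_token)
        cur_subdir.mkdir(parents=True, exist_ok=True)

        # Register only this token's embedding vectors right before its run.
        placeholder_token_ids, _ = add_placeholder_tokens(
            tokenizer=tokenizer,
            embeddings=embeddings,
            placeholder_tokens=[placeholder_token],
            initializer_tokens=[initializer_token],
            num_vectors=num_vectors,
        )

        # Stand-in for the datamodule/optimizer/scheduler/trainer setup.
        run_ti(
            placeholder_tokens=[placeholder_token],
            placeholder_token_ids=placeholder_token_ids,
            output_dir=cur_subdir,
            learning_rate=ti_lr,
            num_train_epochs=ti_train_epochs,
        )

        # Bake this token's trained vectors into the permanent embedding
        # table before moving on to the next token.
        embeddings.persist()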
diff --git a/training/functional.py b/training/functional.py
index f5c111e..1b6162b 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -486,6 +486,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Interrupted")
             on_checkpoint(global_step + global_step_offset, "end")
+        raise KeyboardInterrupt
 
 
 def train(
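The training/functional.py change is a single re-raise: after writing the "end" checkpoint on Ctrl-C, train_loop now propagates the KeyboardInterrupt to its caller, so interrupting one TI run also aborts the whole two-phase script rather than falling through to the next phase. A minimal sketch of the pattern (names and signature are illustrative, not the repo's):

# Sketch of the interrupt-handling pattern in a training loop.
def train_loop(num_steps, on_checkpoint, is_main_process=True):
    global_step = 0
    try:
        for _ in range(num_steps):
            global_step += 1  # one training step per iteration in this sketch
    except KeyboardInterrupt:
        if is_main_process:
            print("Interrupted")
            on_checkpoint(global_step, "end")
        # Re-raise so outer drivers (e.g. a multi-phase training script)
        # stop as well, instead of treating the interrupt as a clean return.
        raise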