diff options
author | Volpeon <git@volpeon.ink> | 2023-01-16 10:03:05 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-16 10:03:05 +0100 |
commit | d5696615a84a768307e82d13e50b4aef64f69dbd (patch) | |
tree | 47cfaa5b8922edbfd567739ecd770977e339f8d7 | |
parent | Implemented extended Dreambooth training (diff) | |
download | textual-inversion-diff-d5696615a84a768307e82d13e50b4aef64f69dbd.tar.gz textual-inversion-diff-d5696615a84a768307e82d13e50b4aef64f69dbd.tar.bz2 textual-inversion-diff-d5696615a84a768307e82d13e50b4aef64f69dbd.zip |
Extended Dreambooth: Train TI tokens separately
-rw-r--r-- | data/csv.py | 7 | ||||
-rw-r--r-- | train_dreambooth.py | 147 | ||||
-rw-r--r-- | training/functional.py | 1 |
3 files changed, 84 insertions, 71 deletions
diff --git a/data/csv.py b/data/csv.py index 2b1e202..002fdd2 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -15,6 +15,9 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt | |||
15 | from models.clip.util import unify_input_ids | 15 | from models.clip.util import unify_input_ids |
16 | 16 | ||
17 | 17 | ||
18 | cache = {} | ||
19 | |||
20 | |||
18 | interpolations = { | 21 | interpolations = { |
19 | "linear": transforms.InterpolationMode.NEAREST, | 22 | "linear": transforms.InterpolationMode.NEAREST, |
20 | "bilinear": transforms.InterpolationMode.BILINEAR, | 23 | "bilinear": transforms.InterpolationMode.BILINEAR, |
@@ -24,9 +27,13 @@ interpolations = { | |||
24 | 27 | ||
25 | 28 | ||
26 | def get_image(path): | 29 | def get_image(path): |
30 | if path in cache: | ||
31 | return cache[path] | ||
32 | |||
27 | image = Image.open(path) | 33 | image = Image.open(path) |
28 | if not image.mode == "RGB": | 34 | if not image.mode == "RGB": |
29 | image = image.convert("RGB") | 35 | image = image.convert("RGB") |
36 | cache[path] = image | ||
30 | return image | 37 | return image |
31 | 38 | ||
32 | 39 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 944256c..05777d0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -510,19 +510,6 @@ def main(): | |||
510 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 510 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
511 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 511 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
512 | 512 | ||
513 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
514 | tokenizer=tokenizer, | ||
515 | embeddings=embeddings, | ||
516 | placeholder_tokens=args.placeholder_tokens, | ||
517 | initializer_tokens=args.initializer_tokens, | ||
518 | num_vectors=args.num_vectors | ||
519 | ) | ||
520 | |||
521 | if len(placeholder_token_ids) != 0: | ||
522 | initializer_token_id_lens = [len(id) for id in initializer_token_ids] | ||
523 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | ||
524 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | ||
525 | |||
526 | if args.scale_lr: | 513 | if args.scale_lr: |
527 | args.learning_rate = ( | 514 | args.learning_rate = ( |
528 | args.learning_rate * args.gradient_accumulation_steps * | 515 | args.learning_rate * args.gradient_accumulation_steps * |
@@ -554,79 +541,98 @@ def main(): | |||
554 | noise_scheduler=noise_scheduler, | 541 | noise_scheduler=noise_scheduler, |
555 | dtype=weight_dtype, | 542 | dtype=weight_dtype, |
556 | seed=args.seed, | 543 | seed=args.seed, |
557 | callbacks_fn=textual_inversion_strategy | ||
558 | ) | 544 | ) |
559 | 545 | ||
560 | # Initial TI | 546 | # Initial TI |
561 | 547 | ||
562 | print("Phase 1: Textual Inversion") | 548 | print("Phase 1: Textual Inversion") |
563 | 549 | ||
550 | ti_lr = 1e-1 | ||
551 | ti_train_epochs = 10 | ||
552 | |||
564 | cur_dir = output_dir.joinpath("1-ti") | 553 | cur_dir = output_dir.joinpath("1-ti") |
565 | cur_dir.mkdir(parents=True, exist_ok=True) | 554 | cur_dir.mkdir(parents=True, exist_ok=True) |
566 | 555 | ||
567 | datamodule = VlpnDataModule( | 556 | for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors): |
568 | data_file=args.train_data_file, | 557 | print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})") |
569 | batch_size=args.train_batch_size, | ||
570 | tokenizer=tokenizer, | ||
571 | class_subdir=args.class_image_dir, | ||
572 | num_class_images=args.num_class_images, | ||
573 | size=args.resolution, | ||
574 | shuffle=not args.no_tag_shuffle, | ||
575 | template_key=args.train_data_template, | ||
576 | valid_set_size=args.valid_set_size, | ||
577 | valid_set_repeat=args.valid_set_repeat, | ||
578 | seed=args.seed, | ||
579 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | ||
580 | dtype=weight_dtype | ||
581 | ) | ||
582 | datamodule.setup() | ||
583 | 558 | ||
584 | optimizer = optimizer_class( | 559 | cur_subdir = cur_dir.joinpath(placeholder_token) |
585 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 560 | cur_subdir.mkdir(parents=True, exist_ok=True) |
586 | lr=2e-1, | ||
587 | weight_decay=0.0, | ||
588 | ) | ||
589 | 561 | ||
590 | lr_scheduler = get_scheduler( | 562 | placeholder_token_ids, _ = add_placeholder_tokens( |
591 | "linear", | 563 | tokenizer=tokenizer, |
592 | optimizer=optimizer, | 564 | embeddings=embeddings, |
593 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 565 | placeholder_tokens=[placeholder_token], |
594 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 566 | initializer_tokens=[initializer_token], |
595 | train_epochs=30, | 567 | num_vectors=num_vectors |
596 | warmup_epochs=10, | 568 | ) |
597 | ) | ||
598 | 569 | ||
599 | trainer( | 570 | datamodule = VlpnDataModule( |
600 | project="textual_inversion", | 571 | data_file=args.train_data_file, |
601 | train_dataloader=datamodule.train_dataloader, | 572 | batch_size=args.train_batch_size, |
602 | val_dataloader=datamodule.val_dataloader, | 573 | tokenizer=tokenizer, |
603 | optimizer=optimizer, | 574 | class_subdir=args.class_image_dir, |
604 | lr_scheduler=lr_scheduler, | 575 | num_class_images=args.num_class_images, |
605 | num_train_epochs=30, | 576 | size=args.resolution, |
606 | sample_frequency=5, | 577 | shuffle=not args.no_tag_shuffle, |
607 | checkpoint_frequency=9999999, | 578 | template_key=args.train_data_template, |
608 | with_prior_preservation=args.num_class_images != 0, | 579 | valid_set_size=args.valid_set_size, |
609 | prior_loss_weight=args.prior_loss_weight, | 580 | valid_set_repeat=args.valid_set_repeat, |
610 | # -- | 581 | seed=args.seed, |
611 | tokenizer=tokenizer, | 582 | filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections), |
612 | sample_scheduler=sample_scheduler, | 583 | dtype=weight_dtype |
613 | output_dir=cur_dir, | 584 | ) |
614 | placeholder_tokens=args.placeholder_tokens, | 585 | datamodule.setup() |
615 | placeholder_token_ids=placeholder_token_ids, | 586 | |
616 | learning_rate=2e-1, | 587 | optimizer = optimizer_class( |
617 | gradient_checkpointing=args.gradient_checkpointing, | 588 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
618 | use_emb_decay=True, | 589 | lr=ti_lr, |
619 | sample_batch_size=args.sample_batch_size, | 590 | weight_decay=0.0, |
620 | sample_num_batches=args.sample_batches, | 591 | ) |
621 | sample_num_steps=args.sample_steps, | 592 | |
622 | sample_image_size=args.sample_image_size, | 593 | lr_scheduler = get_scheduler( |
623 | ) | 594 | "one_cycle", |
595 | optimizer=optimizer, | ||
596 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
597 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
598 | train_epochs=ti_train_epochs, | ||
599 | warmup_epochs=ti_train_epochs // 4, | ||
600 | ) | ||
601 | |||
602 | trainer( | ||
603 | callbacks_fn=textual_inversion_strategy, | ||
604 | project="textual_inversion", | ||
605 | train_dataloader=datamodule.train_dataloader, | ||
606 | val_dataloader=datamodule.val_dataloader, | ||
607 | optimizer=optimizer, | ||
608 | lr_scheduler=lr_scheduler, | ||
609 | num_train_epochs=ti_train_epochs, | ||
610 | sample_frequency=1, | ||
611 | checkpoint_frequency=9999999, | ||
612 | with_prior_preservation=args.num_class_images != 0, | ||
613 | prior_loss_weight=args.prior_loss_weight, | ||
614 | # -- | ||
615 | tokenizer=tokenizer, | ||
616 | sample_scheduler=sample_scheduler, | ||
617 | output_dir=cur_subdir, | ||
618 | placeholder_tokens=[placeholder_token], | ||
619 | placeholder_token_ids=placeholder_token_ids, | ||
620 | learning_rate=ti_lr, | ||
621 | gradient_checkpointing=args.gradient_checkpointing, | ||
622 | use_emb_decay=True, | ||
623 | sample_batch_size=args.sample_batch_size, | ||
624 | sample_num_batches=args.sample_batches, | ||
625 | sample_num_steps=args.sample_steps, | ||
626 | sample_image_size=args.sample_image_size, | ||
627 | ) | ||
628 | |||
629 | embeddings.persist() | ||
624 | 630 | ||
625 | # Dreambooth | 631 | # Dreambooth |
626 | 632 | ||
627 | print("Phase 2: Dreambooth") | 633 | print("Phase 2: Dreambooth") |
628 | 634 | ||
629 | cur_dir = output_dir.joinpath("2db") | 635 | cur_dir = output_dir.joinpath("2-db") |
630 | cur_dir.mkdir(parents=True, exist_ok=True) | 636 | cur_dir.mkdir(parents=True, exist_ok=True) |
631 | 637 | ||
632 | args.seed = (args.seed + 28635) >> 32 | 638 | args.seed = (args.seed + 28635) >> 32 |
@@ -685,6 +691,7 @@ def main(): | |||
685 | ) | 691 | ) |
686 | 692 | ||
687 | trainer( | 693 | trainer( |
694 | callbacks_fn=dreambooth_strategy, | ||
688 | project="dreambooth", | 695 | project="dreambooth", |
689 | train_dataloader=datamodule.train_dataloader, | 696 | train_dataloader=datamodule.train_dataloader, |
690 | val_dataloader=datamodule.val_dataloader, | 697 | val_dataloader=datamodule.val_dataloader, |
@@ -692,14 +699,12 @@ def main(): | |||
692 | lr_scheduler=lr_scheduler, | 699 | lr_scheduler=lr_scheduler, |
693 | num_train_epochs=args.num_train_epochs, | 700 | num_train_epochs=args.num_train_epochs, |
694 | sample_frequency=args.sample_frequency, | 701 | sample_frequency=args.sample_frequency, |
695 | checkpoint_frequency=args.checkpoint_frequency, | ||
696 | with_prior_preservation=args.num_class_images != 0, | 702 | with_prior_preservation=args.num_class_images != 0, |
697 | prior_loss_weight=args.prior_loss_weight, | 703 | prior_loss_weight=args.prior_loss_weight, |
698 | # -- | 704 | # -- |
699 | tokenizer=tokenizer, | 705 | tokenizer=tokenizer, |
700 | sample_scheduler=sample_scheduler, | 706 | sample_scheduler=sample_scheduler, |
701 | output_dir=cur_dir, | 707 | output_dir=cur_dir, |
702 | gradient_checkpointing=args.gradient_checkpointing, | ||
703 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 708 | train_text_encoder_epochs=args.train_text_encoder_epochs, |
704 | max_grad_norm=args.max_grad_norm, | 709 | max_grad_norm=args.max_grad_norm, |
705 | use_ema=args.use_ema, | 710 | use_ema=args.use_ema, |
diff --git a/training/functional.py b/training/functional.py index f5c111e..1b6162b 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -486,6 +486,7 @@ def train_loop( | |||
486 | if accelerator.is_main_process: | 486 | if accelerator.is_main_process: |
487 | print("Interrupted") | 487 | print("Interrupted") |
488 | on_checkpoint(global_step + global_step_offset, "end") | 488 | on_checkpoint(global_step + global_step_offset, "end") |
489 | raise KeyboardInterrupt | ||
489 | 490 | ||
490 | 491 | ||
491 | def train( | 492 | def train( |