summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-16 17:09:01 +0100
committerVolpeon <git@volpeon.ink>2023-01-16 17:09:01 +0100
commit36440e48ce279872d6e736bcb1bf57d13da73a11 (patch)
tree8ba9593d8a887517c70b01932c137c9c3f759e8f /train_ti.py
parentMore training adjustments (diff)
downloadtextual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.tar.gz
textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.tar.bz2
textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.zip
Moved multi-TI code from Dreambooth to TI script
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py221
1 files changed, 114 insertions, 107 deletions
diff --git a/train_ti.py b/train_ti.py
index 7aecdef..adba8d4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -51,6 +51,7 @@ def parse_args():
51 parser.add_argument( 51 parser.add_argument(
52 "--train_data_template", 52 "--train_data_template",
53 type=str, 53 type=str,
54 nargs='*',
54 default="template", 55 default="template",
55 ) 56 )
56 parser.add_argument( 57 parser.add_argument(
@@ -468,11 +469,17 @@ def parse_args():
468 args.num_vectors = 1 469 args.num_vectors = 1
469 470
470 if isinstance(args.num_vectors, int): 471 if isinstance(args.num_vectors, int):
471 args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) 472 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
472 473
473 if len(args.placeholder_tokens) != len(args.num_vectors): 474 if len(args.placeholder_tokens) != len(args.num_vectors):
474 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") 475 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
475 476
477 if isinstance(args.train_data_template, str):
478 args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
479
480 if len(args.placeholder_tokens) != len(args.train_data_template):
481 raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
482
476 if isinstance(args.collection, str): 483 if isinstance(args.collection, str):
477 args.collection = [args.collection] 484 args.collection = [args.collection]
478 485
@@ -507,6 +514,8 @@ def main():
507 514
508 set_seed(args.seed) 515 set_seed(args.seed)
509 516
517 seed_generator = torch.Generator().manual_seed(args.seed)
518
510 save_args(output_dir, args) 519 save_args(output_dir, args)
511 520
512 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 521 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -531,19 +540,6 @@ def main():
531 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 540 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
532 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 541 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
533 542
534 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
535 tokenizer=tokenizer,
536 embeddings=embeddings,
537 placeholder_tokens=args.placeholder_tokens,
538 initializer_tokens=args.initializer_tokens,
539 num_vectors=args.num_vectors
540 )
541
542 if len(placeholder_token_ids) != 0:
543 initializer_token_id_lens = [len(id) for id in initializer_token_ids]
544 placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
545 print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
546
547 if args.scale_lr: 543 if args.scale_lr:
548 args.learning_rate = ( 544 args.learning_rate = (
549 args.learning_rate * args.gradient_accumulation_steps * 545 args.learning_rate * args.gradient_accumulation_steps *
@@ -566,43 +562,6 @@ def main():
566 elif args.mixed_precision == "bf16": 562 elif args.mixed_precision == "bf16":
567 weight_dtype = torch.bfloat16 563 weight_dtype = torch.bfloat16
568 564
569 datamodule = VlpnDataModule(
570 data_file=args.train_data_file,
571 batch_size=args.train_batch_size,
572 tokenizer=tokenizer,
573 class_subdir=args.class_image_dir,
574 num_class_images=args.num_class_images,
575 size=args.resolution,
576 num_buckets=args.num_buckets,
577 progressive_buckets=args.progressive_buckets,
578 bucket_step_size=args.bucket_step_size,
579 bucket_max_pixels=args.bucket_max_pixels,
580 dropout=args.tag_dropout,
581 shuffle=not args.no_tag_shuffle,
582 template_key=args.train_data_template,
583 valid_set_size=args.valid_set_size,
584 train_set_pad=args.train_set_pad,
585 valid_set_pad=args.valid_set_pad,
586 seed=args.seed,
587 filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
588 dtype=weight_dtype
589 )
590 datamodule.setup()
591
592 if args.num_class_images != 0:
593 generate_class_images(
594 accelerator,
595 text_encoder,
596 vae,
597 unet,
598 tokenizer,
599 sample_scheduler,
600 datamodule.train_dataset,
601 args.sample_batch_size,
602 args.sample_image_size,
603 args.sample_steps
604 )
605
606 trainer = partial( 565 trainer = partial(
607 train, 566 train,
608 accelerator=accelerator, 567 accelerator=accelerator,
@@ -615,63 +574,111 @@ def main():
615 callbacks_fn=textual_inversion_strategy 574 callbacks_fn=textual_inversion_strategy
616 ) 575 )
617 576
618 optimizer = optimizer_class( 577 for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
619 text_encoder.text_model.embeddings.temp_token_embedding.parameters(), 578 range(len(args.placeholder_tokens)),
620 lr=args.learning_rate, 579 args.placeholder_tokens,
621 betas=(args.adam_beta1, args.adam_beta2), 580 args.initializer_tokens,
622 weight_decay=args.adam_weight_decay, 581 args.num_vectors,
623 eps=args.adam_epsilon, 582 args.train_data_template
624 amsgrad=args.adam_amsgrad, 583 ):
625 ) 584 cur_subdir = output_dir.joinpath(placeholder_token)
585 cur_subdir.mkdir(parents=True, exist_ok=True)
586
587 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
588 tokenizer=tokenizer,
589 embeddings=embeddings,
590 placeholder_tokens=[placeholder_token],
591 initializer_tokens=[initializer_token],
592 num_vectors=[num_vectors]
593 )
626 594
627 lr_scheduler = get_scheduler( 595 print(
628 args.lr_scheduler, 596 f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
629 optimizer=optimizer, 597
630 num_training_steps_per_epoch=len(datamodule.train_dataloader), 598 args.seed = seed_generator.seed()
631 gradient_accumulation_steps=args.gradient_accumulation_steps, 599
632 min_lr=args.lr_min_lr, 600 datamodule = VlpnDataModule(
633 warmup_func=args.lr_warmup_func, 601 data_file=args.train_data_file,
634 annealing_func=args.lr_annealing_func, 602 batch_size=args.train_batch_size,
635 warmup_exp=args.lr_warmup_exp, 603 tokenizer=tokenizer,
636 annealing_exp=args.lr_annealing_exp, 604 class_subdir=args.class_image_dir,
637 cycles=args.lr_cycles, 605 num_class_images=args.num_class_images,
638 train_epochs=args.num_train_epochs, 606 size=args.resolution,
639 warmup_epochs=args.lr_warmup_epochs, 607 num_buckets=args.num_buckets,
640 ) 608 progressive_buckets=args.progressive_buckets,
641 609 bucket_step_size=args.bucket_step_size,
642 trainer( 610 bucket_max_pixels=args.bucket_max_pixels,
643 project="textual_inversion", 611 dropout=args.tag_dropout,
644 train_dataloader=datamodule.train_dataloader, 612 shuffle=not args.no_tag_shuffle,
645 val_dataloader=datamodule.val_dataloader, 613 template_key=data_template,
646 optimizer=optimizer, 614 valid_set_size=args.valid_set_size,
647 lr_scheduler=lr_scheduler, 615 train_set_pad=args.train_set_pad,
648 num_train_epochs=args.num_train_epochs, 616 valid_set_pad=args.valid_set_pad,
649 sample_frequency=args.sample_frequency, 617 seed=args.seed,
650 checkpoint_frequency=args.checkpoint_frequency, 618 filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
651 global_step_offset=global_step_offset, 619 dtype=weight_dtype
652 with_prior_preservation=args.num_class_images != 0, 620 )
653 prior_loss_weight=args.prior_loss_weight, 621 datamodule.setup()
654 # -- 622
655 tokenizer=tokenizer, 623 optimizer = optimizer_class(
656 sample_scheduler=sample_scheduler, 624 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
657 output_dir=output_dir, 625 lr=args.learning_rate,
658 placeholder_tokens=args.placeholder_tokens, 626 betas=(args.adam_beta1, args.adam_beta2),
659 placeholder_token_ids=placeholder_token_ids, 627 weight_decay=args.adam_weight_decay,
660 learning_rate=args.learning_rate, 628 eps=args.adam_epsilon,
661 gradient_checkpointing=args.gradient_checkpointing, 629 amsgrad=args.adam_amsgrad,
662 use_emb_decay=args.use_emb_decay, 630 )
663 emb_decay_target=args.emb_decay_target, 631
664 emb_decay_factor=args.emb_decay_factor, 632 lr_scheduler = get_scheduler(
665 emb_decay_start=args.emb_decay_start, 633 args.lr_scheduler,
666 use_ema=args.use_ema, 634 optimizer=optimizer,
667 ema_inv_gamma=args.ema_inv_gamma, 635 num_training_steps_per_epoch=len(datamodule.train_dataloader),
668 ema_power=args.ema_power, 636 gradient_accumulation_steps=args.gradient_accumulation_steps,
669 ema_max_decay=args.ema_max_decay, 637 min_lr=args.lr_min_lr,
670 sample_batch_size=args.sample_batch_size, 638 warmup_func=args.lr_warmup_func,
671 sample_num_batches=args.sample_batches, 639 annealing_func=args.lr_annealing_func,
672 sample_num_steps=args.sample_steps, 640 warmup_exp=args.lr_warmup_exp,
673 sample_image_size=args.sample_image_size, 641 annealing_exp=args.lr_annealing_exp,
674 ) 642 cycles=args.lr_cycles,
643 train_epochs=args.num_train_epochs,
644 warmup_epochs=args.lr_warmup_epochs,
645 )
646
647 trainer(
648 project="textual_inversion",
649 train_dataloader=datamodule.train_dataloader,
650 val_dataloader=datamodule.val_dataloader,
651 optimizer=optimizer,
652 lr_scheduler=lr_scheduler,
653 num_train_epochs=args.num_train_epochs,
654 sample_frequency=args.sample_frequency,
655 checkpoint_frequency=args.checkpoint_frequency,
656 global_step_offset=global_step_offset,
657 with_prior_preservation=args.num_class_images != 0,
658 prior_loss_weight=args.prior_loss_weight,
659 # --
660 tokenizer=tokenizer,
661 sample_scheduler=sample_scheduler,
662 output_dir=cur_subdir,
663 placeholder_tokens=[placeholder_token],
664 placeholder_token_ids=placeholder_token_ids,
665 learning_rate=args.learning_rate,
666 gradient_checkpointing=args.gradient_checkpointing,
667 use_emb_decay=args.use_emb_decay,
668 emb_decay_target=args.emb_decay_target,
669 emb_decay_factor=args.emb_decay_factor,
670 emb_decay_start=args.emb_decay_start,
671 use_ema=args.use_ema,
672 ema_inv_gamma=args.ema_inv_gamma,
673 ema_power=args.ema_power,
674 ema_max_decay=args.ema_max_decay,
675 sample_batch_size=args.sample_batch_size,
676 sample_num_batches=args.sample_batches,
677 sample_num_steps=args.sample_steps,
678 sample_image_size=args.sample_image_size,
679 )
680
681 embeddings.persist()
675 682
676 683
677if __name__ == "__main__": 684if __name__ == "__main__":