-rw-r--r--  data/csv.py              2
-rw-r--r--  train_dreambooth.py    135
-rw-r--r--  train_ti.py            221
-rw-r--r--  training/functional.py  17
4 files changed, 131 insertions, 244 deletions
diff --git a/data/csv.py b/data/csv.py
index 85b98f8..6857b6f 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -282,7 +282,7 @@ class VlpnDataModule():
         collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)

         if valid_set_size == 0:
-            data_train, data_val = items, items[:1]
+            data_train, data_val = items, []
         else:
             data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)

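Note on the hunk above: with valid_set_size == 0, the validation split used to reuse the first training item; it is now genuinely empty. A minimal sketch of what an empty split means for a dataloader built from it (the plain DataLoader here is a stand-in for illustration, not the module's actual wrapper):

    from torch.utils.data import DataLoader

    items = list(range(10))           # stand-in for the parsed CSV items
    data_train, data_val = items, []  # behaviour when valid_set_size == 0

    val_dataloader = DataLoader(data_val, batch_size=1)
    assert len(val_dataloader) == 0   # zero batches; a loop over it is a no-op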
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1dc41b1..6511f9b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -200,23 +200,6 @@ def parse_args():
         default=100
     )
     parser.add_argument(
-        "--ti_data_template",
-        type=str,
-        nargs='*',
-        default=[],
-    )
-    parser.add_argument(
-        "--ti_num_train_epochs",
-        type=int,
-        default=10
-    )
-    parser.add_argument(
-        "--ti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
         "--max_train_steps",
         type=int,
         default=None,
@@ -245,12 +228,6 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--ti_learning_rate",
-        type=float,
-        default=1e-2,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -482,12 +459,6 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

-    if isinstance(args.ti_data_template, str):
-        args.ti_data_template = [args.ti_data_template]
-
-    if len(args.ti_data_template) == 0:
-        raise ValueError("You must specify --ti_data_template")
-
     if isinstance(args.collection, str):
         args.collection = [args.collection]

@@ -521,8 +492,6 @@ def main():

     set_seed(args.seed)

-    seed_generator = torch.Generator().manual_seed(args.seed)
-
     save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -583,107 +552,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )

-    # Initial TI
-
-    print("Phase 1: Textual Inversion")
-
-    cur_dir = output_dir.joinpath("1-ti")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.ti_data_template
-    ):
-        cur_subdir = cur_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
-
-        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-            tokenizer=tokenizer,
-            embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
-        )
-
-        print(
-            f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
-
-        args.seed = seed_generator.seed()
-
-        datamodule = VlpnDataModule(
-            data_file=args.train_data_file,
-            batch_size=args.ti_batch_size,
-            tokenizer=tokenizer,
-            class_subdir=args.class_image_dir,
-            num_class_images=args.num_class_images,
-            size=args.resolution,
-            shuffle=not args.no_tag_shuffle,
-            template_key=data_template,
-            valid_set_size=1,
-            train_set_pad=args.train_set_pad,
-            valid_set_pad=args.valid_set_pad,
-            seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
-            dtype=weight_dtype
-        )
-        datamodule.setup()
-
-        optimizer = optimizer_class(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            lr=args.ti_learning_rate,
-            betas=(args.adam_beta1, args.adam_beta2),
-            weight_decay=0.0,
-            eps=args.adam_epsilon,
-        )
-
-        lr_scheduler = get_scheduler(
-            "one_cycle",
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(datamodule.train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            train_epochs=args.ti_num_train_epochs,
-        )
-
-        trainer(
-            callbacks_fn=textual_inversion_strategy,
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            seed=args.seed,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=args.ti_num_train_epochs,
-            sample_frequency=args.ti_num_train_epochs // 5,
-            checkpoint_frequency=9999999,
-            # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
-            placeholder_tokens=[placeholder_token],
-            placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.ti_learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=True,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
-        )
-
-        embeddings.persist()
-
-    # Dreambooth
-
-    print("Phase 2: Dreambooth")
-
-    cur_dir = output_dir.joinpath("2-db")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    args.seed = seed_generator.seed()
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -746,12 +614,13 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        prepare_unet=True,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=cur_dir,
+        output_dir=output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
diff --git a/train_ti.py b/train_ti.py
index 7aecdef..adba8d4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -51,6 +51,7 @@ def parse_args():
     parser.add_argument(
         "--train_data_template",
         type=str,
+        nargs='*',
         default="template",
     )
     parser.add_argument(
@@ -468,11 +469,17 @@ def parse_args():
         args.num_vectors = 1

     if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
+        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)

     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

+    if isinstance(args.train_data_template, str):
+        args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
+
+    if len(args.placeholder_tokens) != len(args.train_data_template):
+        raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]

@@ -507,6 +514,8 @@ def main():

     set_seed(args.seed)

+    seed_generator = torch.Generator().manual_seed(args.seed)
+
     save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -531,19 +540,6 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

-    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-        tokenizer=tokenizer,
-        embeddings=embeddings,
-        placeholder_tokens=args.placeholder_tokens,
-        initializer_tokens=args.initializer_tokens,
-        num_vectors=args.num_vectors
-    )
-
-    if len(placeholder_token_ids) != 0:
-        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
-        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
-        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -566,43 +562,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    datamodule = VlpnDataModule(
-        data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
-        class_subdir=args.class_image_dir,
-        num_class_images=args.num_class_images,
-        size=args.resolution,
-        num_buckets=args.num_buckets,
-        progressive_buckets=args.progressive_buckets,
-        bucket_step_size=args.bucket_step_size,
-        bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
-        shuffle=not args.no_tag_shuffle,
-        template_key=args.train_data_template,
-        valid_set_size=args.valid_set_size,
-        train_set_pad=args.train_set_pad,
-        valid_set_pad=args.valid_set_pad,
-        seed=args.seed,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-        dtype=weight_dtype
-    )
-    datamodule.setup()
-
-    if args.num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            sample_scheduler,
-            datamodule.train_dataset,
-            args.sample_batch_size,
-            args.sample_image_size,
-            args.sample_steps
-        )
-
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -615,63 +574,111 @@ def main():
         callbacks_fn=textual_inversion_strategy
     )

-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        train_epochs=args.num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-    )
-
-    trainer(
-        project="textual_inversion",
-        train_dataloader=datamodule.train_dataloader,
-        val_dataloader=datamodule.val_dataloader,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
-        sample_frequency=args.sample_frequency,
-        checkpoint_frequency=args.checkpoint_frequency,
-        global_step_offset=global_step_offset,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
-        # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
-        output_dir=output_dir,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        learning_rate=args.learning_rate,
-        gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay_factor=args.emb_decay_factor,
-        emb_decay_start=args.emb_decay_start,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
-        sample_batch_size=args.sample_batch_size,
-        sample_num_batches=args.sample_batches,
-        sample_num_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
-    )
+    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
+        range(len(args.placeholder_tokens)),
+        args.placeholder_tokens,
+        args.initializer_tokens,
+        args.num_vectors,
+        args.train_data_template
+    ):
+        cur_subdir = output_dir.joinpath(placeholder_token)
+        cur_subdir.mkdir(parents=True, exist_ok=True)
+
+        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=[placeholder_token],
+            initializer_tokens=[initializer_token],
+            num_vectors=[num_vectors]
+        )
+
+        print(
+            f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
+
+        args.seed = seed_generator.seed()
+
+        datamodule = VlpnDataModule(
+            data_file=args.train_data_file,
+            batch_size=args.train_batch_size,
+            tokenizer=tokenizer,
+            class_subdir=args.class_image_dir,
+            num_class_images=args.num_class_images,
+            size=args.resolution,
+            num_buckets=args.num_buckets,
+            progressive_buckets=args.progressive_buckets,
+            bucket_step_size=args.bucket_step_size,
+            bucket_max_pixels=args.bucket_max_pixels,
+            dropout=args.tag_dropout,
+            shuffle=not args.no_tag_shuffle,
+            template_key=data_template,
+            valid_set_size=args.valid_set_size,
+            train_set_pad=args.train_set_pad,
+            valid_set_pad=args.valid_set_pad,
+            seed=args.seed,
+            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
+            dtype=weight_dtype
+        )
+        datamodule.setup()
+
+        optimizer = optimizer_class(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_training_steps_per_epoch=len(datamodule.train_dataloader),
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
+        )
+
+        trainer(
+            project="textual_inversion",
+            train_dataloader=datamodule.train_dataloader,
+            val_dataloader=datamodule.val_dataloader,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            num_train_epochs=args.num_train_epochs,
+            sample_frequency=args.sample_frequency,
+            checkpoint_frequency=args.checkpoint_frequency,
+            global_step_offset=global_step_offset,
+            with_prior_preservation=args.num_class_images != 0,
+            prior_loss_weight=args.prior_loss_weight,
+            # --
+            tokenizer=tokenizer,
+            sample_scheduler=sample_scheduler,
+            output_dir=cur_subdir,
+            placeholder_tokens=[placeholder_token],
+            placeholder_token_ids=placeholder_token_ids,
+            learning_rate=args.learning_rate,
+            gradient_checkpointing=args.gradient_checkpointing,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay_factor=args.emb_decay_factor,
+            emb_decay_start=args.emb_decay_start,
+            use_ema=args.use_ema,
+            ema_inv_gamma=args.ema_inv_gamma,
+            ema_power=args.ema_power,
+            ema_max_decay=args.ema_max_decay,
+            sample_batch_size=args.sample_batch_size,
+            sample_num_batches=args.sample_batches,
+            sample_num_steps=args.sample_steps,
+            sample_image_size=args.sample_image_size,
+        )
+
+        embeddings.persist()


 if __name__ == "__main__":
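Note on the new per-token loop above: args.seed = seed_generator.seed() re-seeds every phase. In PyTorch, Generator.seed() draws a fresh non-deterministic seed and returns it; the value is not derived from the earlier manual_seed(args.seed) call, so each placeholder-token phase runs with an independent seed. A small illustrative sketch (the 1234 value is arbitrary):

    import torch

    gen = torch.Generator().manual_seed(1234)  # mirrors seed_generator in main()

    # seed() re-seeds from a non-deterministic source and returns the new seed;
    # the result is unrelated to the manual_seed(1234) above.
    seed_a = gen.seed()
    seed_b = gen.seed()
    print(seed_a, seed_b)                      # two independent integers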
diff --git a/training/functional.py b/training/functional.py
index b6b5d87..1548784 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -515,6 +515,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
+    prepare_unet: bool = False,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -523,9 +524,19 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
+    prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
+
+    if prepare_unet:
+        prep.append(unet)
+
+    prep = accelerator.prepare(*prep)
+
+    if prepare_unet:
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
+    else:
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
+
+    unet.to(accelerator.device, dtype=dtype)

     vae.to(accelerator.device, dtype=dtype)

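Note on the prepare_unet change above: it relies on accelerator.prepare returning the wrapped objects in the same order they were passed in, which is what lets the UNet be appended and unpacked conditionally. A reduced sketch of the pattern with toy modules (assumes the accelerate package is available; the names here are stand-ins, not the project's):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    text_encoder = torch.nn.Linear(4, 4)   # stand-ins for the real models
    unet = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(text_encoder.parameters())

    prepare_unet = True

    prep = [text_encoder, optimizer]
    if prepare_unet:
        prep.append(unet)

    # prepare(*objs) hands back the wrapped objects in input order
    prep = accelerator.prepare(*prep)

    if prepare_unet:
        text_encoder, optimizer, unet = prep
    else:
        text_encoder, optimizer = prep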