diff options
| -rw-r--r-- | data/csv.py | 7 | ||||
| -rw-r--r-- | train_dreambooth.py | 147 | ||||
| -rw-r--r-- | training/functional.py | 1 | 
3 files changed, 84 insertions, 71 deletions
| diff --git a/data/csv.py b/data/csv.py index 2b1e202..002fdd2 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -15,6 +15,9 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt | |||
| 15 | from models.clip.util import unify_input_ids | 15 | from models.clip.util import unify_input_ids | 
| 16 | 16 | ||
| 17 | 17 | ||
| 18 | cache = {} | ||
| 19 | |||
| 20 | |||
| 18 | interpolations = { | 21 | interpolations = { | 
| 19 | "linear": transforms.InterpolationMode.NEAREST, | 22 | "linear": transforms.InterpolationMode.NEAREST, | 
| 20 | "bilinear": transforms.InterpolationMode.BILINEAR, | 23 | "bilinear": transforms.InterpolationMode.BILINEAR, | 
| @@ -24,9 +27,13 @@ interpolations = { | |||
| 24 | 27 | ||
| 25 | 28 | ||
| 26 | def get_image(path): | 29 | def get_image(path): | 
| 30 | if path in cache: | ||
| 31 | return cache[path] | ||
| 32 | |||
| 27 | image = Image.open(path) | 33 | image = Image.open(path) | 
| 28 | if not image.mode == "RGB": | 34 | if not image.mode == "RGB": | 
| 29 | image = image.convert("RGB") | 35 | image = image.convert("RGB") | 
| 36 | cache[path] = image | ||
| 30 | return image | 37 | return image | 
| 31 | 38 | ||
| 32 | 39 | ||
| diff --git a/train_dreambooth.py b/train_dreambooth.py index 944256c..05777d0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -510,19 +510,6 @@ def main(): | |||
| 510 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 510 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 
| 511 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 511 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 
| 512 | 512 | ||
| 513 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
| 514 | tokenizer=tokenizer, | ||
| 515 | embeddings=embeddings, | ||
| 516 | placeholder_tokens=args.placeholder_tokens, | ||
| 517 | initializer_tokens=args.initializer_tokens, | ||
| 518 | num_vectors=args.num_vectors | ||
| 519 | ) | ||
| 520 | |||
| 521 | if len(placeholder_token_ids) != 0: | ||
| 522 | initializer_token_id_lens = [len(id) for id in initializer_token_ids] | ||
| 523 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | ||
| 524 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | ||
| 525 | |||
| 526 | if args.scale_lr: | 513 | if args.scale_lr: | 
| 527 | args.learning_rate = ( | 514 | args.learning_rate = ( | 
| 528 | args.learning_rate * args.gradient_accumulation_steps * | 515 | args.learning_rate * args.gradient_accumulation_steps * | 
| @@ -554,79 +541,98 @@ def main(): | |||
| 554 | noise_scheduler=noise_scheduler, | 541 | noise_scheduler=noise_scheduler, | 
| 555 | dtype=weight_dtype, | 542 | dtype=weight_dtype, | 
| 556 | seed=args.seed, | 543 | seed=args.seed, | 
| 557 | callbacks_fn=textual_inversion_strategy | ||
| 558 | ) | 544 | ) | 
| 559 | 545 | ||
| 560 | # Initial TI | 546 | # Initial TI | 
| 561 | 547 | ||
| 562 | print("Phase 1: Textual Inversion") | 548 | print("Phase 1: Textual Inversion") | 
| 563 | 549 | ||
| 550 | ti_lr = 1e-1 | ||
| 551 | ti_train_epochs = 10 | ||
| 552 | |||
| 564 | cur_dir = output_dir.joinpath("1-ti") | 553 | cur_dir = output_dir.joinpath("1-ti") | 
| 565 | cur_dir.mkdir(parents=True, exist_ok=True) | 554 | cur_dir.mkdir(parents=True, exist_ok=True) | 
| 566 | 555 | ||
| 567 | datamodule = VlpnDataModule( | 556 | for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors): | 
| 568 | data_file=args.train_data_file, | 557 | print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})") | 
| 569 | batch_size=args.train_batch_size, | ||
| 570 | tokenizer=tokenizer, | ||
| 571 | class_subdir=args.class_image_dir, | ||
| 572 | num_class_images=args.num_class_images, | ||
| 573 | size=args.resolution, | ||
| 574 | shuffle=not args.no_tag_shuffle, | ||
| 575 | template_key=args.train_data_template, | ||
| 576 | valid_set_size=args.valid_set_size, | ||
| 577 | valid_set_repeat=args.valid_set_repeat, | ||
| 578 | seed=args.seed, | ||
| 579 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | ||
| 580 | dtype=weight_dtype | ||
| 581 | ) | ||
| 582 | datamodule.setup() | ||
| 583 | 558 | ||
| 584 | optimizer = optimizer_class( | 559 | cur_subdir = cur_dir.joinpath(placeholder_token) | 
| 585 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 560 | cur_subdir.mkdir(parents=True, exist_ok=True) | 
| 586 | lr=2e-1, | ||
| 587 | weight_decay=0.0, | ||
| 588 | ) | ||
| 589 | 561 | ||
| 590 | lr_scheduler = get_scheduler( | 562 | placeholder_token_ids, _ = add_placeholder_tokens( | 
| 591 | "linear", | 563 | tokenizer=tokenizer, | 
| 592 | optimizer=optimizer, | 564 | embeddings=embeddings, | 
| 593 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 565 | placeholder_tokens=[placeholder_token], | 
| 594 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 566 | initializer_tokens=[initializer_token], | 
| 595 | train_epochs=30, | 567 | num_vectors=num_vectors | 
| 596 | warmup_epochs=10, | 568 | ) | 
| 597 | ) | ||
| 598 | 569 | ||
| 599 | trainer( | 570 | datamodule = VlpnDataModule( | 
| 600 | project="textual_inversion", | 571 | data_file=args.train_data_file, | 
| 601 | train_dataloader=datamodule.train_dataloader, | 572 | batch_size=args.train_batch_size, | 
| 602 | val_dataloader=datamodule.val_dataloader, | 573 | tokenizer=tokenizer, | 
| 603 | optimizer=optimizer, | 574 | class_subdir=args.class_image_dir, | 
| 604 | lr_scheduler=lr_scheduler, | 575 | num_class_images=args.num_class_images, | 
| 605 | num_train_epochs=30, | 576 | size=args.resolution, | 
| 606 | sample_frequency=5, | 577 | shuffle=not args.no_tag_shuffle, | 
| 607 | checkpoint_frequency=9999999, | 578 | template_key=args.train_data_template, | 
| 608 | with_prior_preservation=args.num_class_images != 0, | 579 | valid_set_size=args.valid_set_size, | 
| 609 | prior_loss_weight=args.prior_loss_weight, | 580 | valid_set_repeat=args.valid_set_repeat, | 
| 610 | # -- | 581 | seed=args.seed, | 
| 611 | tokenizer=tokenizer, | 582 | filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections), | 
| 612 | sample_scheduler=sample_scheduler, | 583 | dtype=weight_dtype | 
| 613 | output_dir=cur_dir, | 584 | ) | 
| 614 | placeholder_tokens=args.placeholder_tokens, | 585 | datamodule.setup() | 
| 615 | placeholder_token_ids=placeholder_token_ids, | 586 | |
| 616 | learning_rate=2e-1, | 587 | optimizer = optimizer_class( | 
| 617 | gradient_checkpointing=args.gradient_checkpointing, | 588 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 
| 618 | use_emb_decay=True, | 589 | lr=ti_lr, | 
| 619 | sample_batch_size=args.sample_batch_size, | 590 | weight_decay=0.0, | 
| 620 | sample_num_batches=args.sample_batches, | 591 | ) | 
| 621 | sample_num_steps=args.sample_steps, | 592 | |
| 622 | sample_image_size=args.sample_image_size, | 593 | lr_scheduler = get_scheduler( | 
| 623 | ) | 594 | "one_cycle", | 
| 595 | optimizer=optimizer, | ||
| 596 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
| 597 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 598 | train_epochs=ti_train_epochs, | ||
| 599 | warmup_epochs=ti_train_epochs // 4, | ||
| 600 | ) | ||
| 601 | |||
| 602 | trainer( | ||
| 603 | callbacks_fn=textual_inversion_strategy, | ||
| 604 | project="textual_inversion", | ||
| 605 | train_dataloader=datamodule.train_dataloader, | ||
| 606 | val_dataloader=datamodule.val_dataloader, | ||
| 607 | optimizer=optimizer, | ||
| 608 | lr_scheduler=lr_scheduler, | ||
| 609 | num_train_epochs=ti_train_epochs, | ||
| 610 | sample_frequency=1, | ||
| 611 | checkpoint_frequency=9999999, | ||
| 612 | with_prior_preservation=args.num_class_images != 0, | ||
| 613 | prior_loss_weight=args.prior_loss_weight, | ||
| 614 | # -- | ||
| 615 | tokenizer=tokenizer, | ||
| 616 | sample_scheduler=sample_scheduler, | ||
| 617 | output_dir=cur_subdir, | ||
| 618 | placeholder_tokens=[placeholder_token], | ||
| 619 | placeholder_token_ids=placeholder_token_ids, | ||
| 620 | learning_rate=ti_lr, | ||
| 621 | gradient_checkpointing=args.gradient_checkpointing, | ||
| 622 | use_emb_decay=True, | ||
| 623 | sample_batch_size=args.sample_batch_size, | ||
| 624 | sample_num_batches=args.sample_batches, | ||
| 625 | sample_num_steps=args.sample_steps, | ||
| 626 | sample_image_size=args.sample_image_size, | ||
| 627 | ) | ||
| 628 | |||
| 629 | embeddings.persist() | ||
| 624 | 630 | ||
| 625 | # Dreambooth | 631 | # Dreambooth | 
| 626 | 632 | ||
| 627 | print("Phase 2: Dreambooth") | 633 | print("Phase 2: Dreambooth") | 
| 628 | 634 | ||
| 629 | cur_dir = output_dir.joinpath("2db") | 635 | cur_dir = output_dir.joinpath("2-db") | 
| 630 | cur_dir.mkdir(parents=True, exist_ok=True) | 636 | cur_dir.mkdir(parents=True, exist_ok=True) | 
| 631 | 637 | ||
| 632 | args.seed = (args.seed + 28635) >> 32 | 638 | args.seed = (args.seed + 28635) >> 32 | 
| @@ -685,6 +691,7 @@ def main(): | |||
| 685 | ) | 691 | ) | 
| 686 | 692 | ||
| 687 | trainer( | 693 | trainer( | 
| 694 | callbacks_fn=dreambooth_strategy, | ||
| 688 | project="dreambooth", | 695 | project="dreambooth", | 
| 689 | train_dataloader=datamodule.train_dataloader, | 696 | train_dataloader=datamodule.train_dataloader, | 
| 690 | val_dataloader=datamodule.val_dataloader, | 697 | val_dataloader=datamodule.val_dataloader, | 
| @@ -692,14 +699,12 @@ def main(): | |||
| 692 | lr_scheduler=lr_scheduler, | 699 | lr_scheduler=lr_scheduler, | 
| 693 | num_train_epochs=args.num_train_epochs, | 700 | num_train_epochs=args.num_train_epochs, | 
| 694 | sample_frequency=args.sample_frequency, | 701 | sample_frequency=args.sample_frequency, | 
| 695 | checkpoint_frequency=args.checkpoint_frequency, | ||
| 696 | with_prior_preservation=args.num_class_images != 0, | 702 | with_prior_preservation=args.num_class_images != 0, | 
| 697 | prior_loss_weight=args.prior_loss_weight, | 703 | prior_loss_weight=args.prior_loss_weight, | 
| 698 | # -- | 704 | # -- | 
| 699 | tokenizer=tokenizer, | 705 | tokenizer=tokenizer, | 
| 700 | sample_scheduler=sample_scheduler, | 706 | sample_scheduler=sample_scheduler, | 
| 701 | output_dir=cur_dir, | 707 | output_dir=cur_dir, | 
| 702 | gradient_checkpointing=args.gradient_checkpointing, | ||
| 703 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 708 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 
| 704 | max_grad_norm=args.max_grad_norm, | 709 | max_grad_norm=args.max_grad_norm, | 
| 705 | use_ema=args.use_ema, | 710 | use_ema=args.use_ema, | 
| diff --git a/training/functional.py b/training/functional.py index f5c111e..1b6162b 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -486,6 +486,7 @@ def train_loop( | |||
| 486 | if accelerator.is_main_process: | 486 | if accelerator.is_main_process: | 
| 487 | print("Interrupted") | 487 | print("Interrupted") | 
| 488 | on_checkpoint(global_step + global_step_offset, "end") | 488 | on_checkpoint(global_step + global_step_offset, "end") | 
| 489 | raise KeyboardInterrupt | ||
| 489 | 490 | ||
| 490 | 491 | ||
| 491 | def train( | 492 | def train( | 
