 dreambooth.py         | 104
 dreambooth_plus.py    |   7
 models/clip/prompt.py |   2
 3 files changed, 76 insertions(+), 37 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 9786e0f..d1cf535 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,4 +1,5 @@
 import argparse
+import itertools
 import math
 import os
 import datetime
@@ -113,7 +114,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1200,
+        default=3600,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -128,9 +129,15 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate",
+        "--learning_rate_unet",
         type=float,
-        default=5e-5,
+        default=3e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--learning_rate_text",
+        type=float,
+        default=3e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -358,12 +365,14 @@ class Checkpointer:
     def save_model(self):
         print("Saving model...")

-        unwrapped = self.accelerator.unwrap_model(
+        unwrapped_unet = self.accelerator.unwrap_model(
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
+        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
         pipeline = VlpnStableDiffusion(
-            text_encoder=self.text_encoder,
+            text_encoder=unwrapped_text_encoder,
             vae=self.vae,
-            unet=unwrapped,
+            unet=unwrapped_unet,
             tokenizer=self.tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
@@ -371,7 +380,8 @@ class Checkpointer:
         )
         pipeline.save_pretrained(self.output_dir.joinpath("model"))

-        del unwrapped
+        del unwrapped_unet
+        del unwrapped_text_encoder
         del pipeline

         if torch.cuda.is_available():
@@ -381,16 +391,18 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(
+        unwrapped_unet = self.accelerator.unwrap_model(
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
+        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )

         pipeline = VlpnStableDiffusion(
-            text_encoder=self.text_encoder,
+            text_encoder=unwrapped_text_encoder,
             vae=self.vae,
-            unet=unwrapped,
+            unet=unwrapped_unet,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
         ).to(self.accelerator.device)
@@ -416,9 +428,16 @@ class Checkpointer:

             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                prompt = [
+                    prompt.format(self.instance_identifier)
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ][:self.sample_batch_size]
+                nprompt = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ][:self.sample_batch_size]

                 samples = pipeline(
                     prompt=prompt,
@@ -443,7 +462,8 @@ class Checkpointer:
            del all_samples
            del image_grid

-        del unwrapped
+        del unwrapped_unet
+        del unwrapped_text_encoder
         del scheduler
         del pipeline
         del generator
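Both Checkpointer hunks lean on `accelerator.unwrap_model()`, which strips whatever wrapper `prepare()` added (e.g. `DistributedDataParallel`) before the weights are saved or handed to a pipeline. A minimal sketch of that round trip, using a stand-in `nn.Linear` instead of the real models:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)          # stand-in for the UNet/text encoder
prepared = accelerator.prepare(model)  # may wrap the module for DDP etc.

# unwrap_model() recovers the original module, which is what the
# Checkpointer needs for save_pretrained() and pipeline construction.
unwrapped = accelerator.unwrap_model(prepared)
assert isinstance(unwrapped, torch.nn.Linear)
```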
@@ -482,8 +502,7 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='text_encoder')
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')

@@ -499,17 +518,21 @@ def main():

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()

     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)

     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
-    freeze_params(text_encoder.parameters())

     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
+        args.learning_rate_unet = (
+            args.learning_rate_unet * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+        args.learning_rate_text = (
+            args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
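For reference, `--scale_lr` multiplies each base rate by the effective batch size. A toy calculation with assumed values (only the 3e-6 base rate comes from the diff; the rest are hypothetical):

```python
# Hypothetical settings: 4 accumulation steps, batch size 1, 2 processes.
base_lr = 3e-6                  # new --learning_rate_unet default
gradient_accumulation_steps = 4
train_batch_size = 1
num_processes = 2

effective_lr = base_lr * gradient_accumulation_steps * train_batch_size * num_processes
print(effective_lr)             # 2.4e-05
```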
@@ -526,8 +549,16 @@ def main():

     # Initialize the optimizer
     optimizer = optimizer_class(
-        unet.parameters(),  # only optimize unet
-        lr=args.learning_rate,
+        [
+            {
+                'params': unet.parameters(),
+                'lr': args.learning_rate_unet,
+            },
+            {
+                'params': text_encoder.parameters(),
+                'lr': args.learning_rate_text,
+            }
+        ],
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
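The list-of-dicts form is PyTorch's standard per-parameter-group API: each dict carries its own `lr`, while `betas`, `weight_decay`, and `eps` are shared by all groups. A self-contained sketch with stand-in modules and assumed hyperparameters:

```python
import torch

unet = torch.nn.Linear(8, 8)          # stand-in for UNet2DConditionModel
text_encoder = torch.nn.Linear(8, 8)  # stand-in for CLIPTextModel

optimizer = torch.optim.AdamW(
    [
        {'params': unet.parameters(), 'lr': 3e-6},
        {'params': text_encoder.parameters(), 'lr': 3e-6},
    ],
    betas=(0.9, 0.999),   # shared by both groups
    weight_decay=1e-2,
    eps=1e-8,
)

# One entry per dict; schedulers can now treat the two rates separately.
assert len(optimizer.param_groups) == 2
```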
@@ -592,8 +623,10 @@ def main():
         missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]

         if len(missing_data) != 0:
-            batched_data = [missing_data[i:i+args.sample_batch_size]
-                            for i in range(0, len(missing_data), args.sample_batch_size)]
+            batched_data = [
+                missing_data[i:i+args.sample_batch_size]
+                for i in range(0, len(missing_data), args.sample_batch_size)
+            ]

             scheduler = EulerAScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
@@ -610,9 +643,9 @@ def main():

             with torch.inference_mode():
                 for batch in batched_data:
-                    image_name = [p.class_image_path for p in batch]
-                    prompt = [p.prompt.format(args.class_identifier) for p in batch]
-                    nprompt = [p.nprompt for p in batch]
+                    image_name = [item.class_image_path for item in batch]
+                    prompt = [item.prompt.format(args.class_identifier) for item in batch]
+                    nprompt = [item.nprompt for item in batch]

                     images = pipeline(
                         prompt=prompt,
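The `batched_data` rewrite two hunks up is a plain slice-based chunking idiom, shown here on toy data (the last chunk is simply shorter):

```python
missing_data = list(range(7))  # toy stand-in for the dataset items
sample_batch_size = 3

batched_data = [
    missing_data[i:i+sample_batch_size]
    for i in range(0, len(missing_data), sample_batch_size)
]
print(batched_data)  # [[0, 1, 2], [3, 4, 5], [6]]
```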
@@ -655,16 +688,14 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )

-    unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

     # Move text_encoder and vae to device
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

     # Keep text_encoder and vae in eval mode as we don't train these
-    text_encoder.eval()
     vae.eval()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
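`accelerator.prepare()` returns its arguments in the order they were passed, which is why adding the text encoder only requires inserting it symmetrically on both sides of the assignment. A runnable sketch with stand-ins:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

# prepare() wraps each object for the current device/distributed setup
# and hands them back in the same order.
unet, text_encoder, optimizer = accelerator.prepare(unet, text_encoder, optimizer)
```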
@@ -736,12 +767,13 @@ def main():
         local_progress_bar.reset()

         unet.train()
+        text_encoder.train()
         train_loss = 0.0

         sample_checkpoint = False

         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(unet):
+            with accelerator.accumulate(itertools.chain(unet, text_encoder)):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
@@ -782,7 +814,8 @@ def main():

                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                    accelerator.clip_grad_norm_(itertools.chain(
+                        unet.parameters(), text_encoder.parameters()), args.max_grad_norm)
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
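Chaining the two `.parameters()` iterators means a single global norm is computed over the union of both models' gradients, rather than clipping each model independently. The same pattern with `torch.nn.utils.clip_grad_norm_`, which `accelerator.clip_grad_norm_` mirrors:

```python
import itertools
import torch

unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)

loss = unet(torch.randn(2, 4)).sum() + text_encoder(torch.randn(2, 4)).sum()
loss.backward()

# One norm over both parameter sets, clipped to max_norm=1.0.
torch.nn.utils.clip_grad_norm_(
    itertools.chain(unet.parameters(), text_encoder.parameters()),
    max_norm=1.0,
)
```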
@@ -804,7 +837,11 @@ def main():
                 if global_step % args.sample_frequency == 0:
                     sample_checkpoint = True

-                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {
+                    "train/loss": loss,
+                    "lr/unet": lr_scheduler.get_last_lr()[0],
+                    "lr/text": lr_scheduler.get_last_lr()[1]
+                }
                 if args.use_ema:
                     logs["ema_decay"] = ema_unet.decay

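`get_last_lr()` returns one value per optimizer param group, in registration order, so index 0 is the UNet group and index 1 the text-encoder group from the optimizer change above. A small sketch:

```python
import torch

a = torch.nn.Linear(2, 2)
b = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW([
    {'params': a.parameters(), 'lr': 3e-6},
    {'params': b.parameters(), 'lr': 3e-6},
])
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
print(scheduler.get_last_lr())  # [3e-06, 3e-06] — one entry per group
```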
@@ -820,6 +857,7 @@ def main():
         accelerator.wait_for_everyone()

         unet.eval()
+        text_encoder.eval()
         val_loss = 0.0

         for step, batch in enumerate(val_dataloader):
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 06ff45b..413abe3 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -125,7 +125,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2400,
+        default=4700,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -142,13 +142,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -578,6 +578,7 @@ def main():

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()

     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)
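Both files now enable checkpointing on the text encoder as well: `gradient_checkpointing_enable()` (transformers) and `enable_gradient_checkpointing()` (diffusers) trade compute for memory by recomputing activations during the backward pass. The underlying mechanism, sketched with `torch.utils.checkpoint` (the `use_reentrant` flag assumes a reasonably recent PyTorch):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
x = torch.randn(2, 16, requires_grad=True)

# Activations inside `layer` are discarded after the forward pass and
# recomputed when backward() needs them, cutting peak memory.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
```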
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index c1e3340..259ac44 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import List, Union

 import torch

