-rw-r--r--  dreambooth.py         | 104 ++++++++++++++++++++++++------------
-rw-r--r--  dreambooth_plus.py    |   7 ++++---
-rw-r--r--  models/clip/prompt.py |   2 +-

3 files changed, 76 insertions(+), 37 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 9786e0f..d1cf535 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,4 +1,5 @@
 import argparse
+import itertools
 import math
 import os
 import datetime
@@ -113,7 +114,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1200,
+        default=3600,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -128,9 +129,15 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate",
+        "--learning_rate_unet",
         type=float,
-        default=5e-5,
+        default=3e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--learning_rate_text",
+        type=float,
+        default=3e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -358,12 +365,14 @@ class Checkpointer:
     def save_model(self):
         print("Saving model...")

-        unwrapped = self.accelerator.unwrap_model(
+        unwrapped_unet = self.accelerator.unwrap_model(
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
+        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
         pipeline = VlpnStableDiffusion(
-            text_encoder=self.text_encoder,
+            text_encoder=unwrapped_text_encoder,
             vae=self.vae,
-            unet=unwrapped,
+            unet=unwrapped_unet,
             tokenizer=self.tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
@@ -371,7 +380,8 @@ class Checkpointer:
         )
         pipeline.save_pretrained(self.output_dir.joinpath("model"))

-        del unwrapped
+        del unwrapped_unet
+        del unwrapped_text_encoder
         del pipeline

         if torch.cuda.is_available():
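
Note: because the text encoder is now trained (and possibly wrapped by accelerate), it has to be unwrapped before checkpointing, just like the UNet. A minimal sketch of the pattern, with a stand-in module and an illustrative save path (not code from this commit):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 4)        # stand-in for unet / text_encoder
    model = accelerator.prepare(model)   # may wrap the module, e.g. in DDP

    # ... training ...

    bare = accelerator.unwrap_model(model)      # strip the wrapper first
    torch.save(bare.state_dict(), "model.pt")   # illustrative save path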
@@ -381,16 +391,18 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(
+        unwrapped_unet = self.accelerator.unwrap_model(
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
+        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )

         pipeline = VlpnStableDiffusion(
-            text_encoder=self.text_encoder,
+            text_encoder=unwrapped_text_encoder,
             vae=self.vae,
-            unet=unwrapped,
+            unet=unwrapped_unet,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
         ).to(self.accelerator.device)
@@ -416,9 +428,16 @@ class Checkpointer:

             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                prompt = [
+                    prompt.format(self.instance_identifier)
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ][:self.sample_batch_size]
+                nprompt = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ][:self.sample_batch_size]

                 samples = pipeline(
                     prompt=prompt,
@@ -443,7 +462,8 @@ class Checkpointer:
             del all_samples
             del image_grid

-        del unwrapped
+        del unwrapped_unet
+        del unwrapped_text_encoder
         del scheduler
         del pipeline
         del generator
@@ -482,8 +502,7 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='text_encoder')
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')

@@ -499,17 +518,21 @@ def main():

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()

     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)

     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
-    freeze_params(text_encoder.parameters())

     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
+        args.learning_rate_unet = (
+            args.learning_rate_unet * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+        args.learning_rate_text = (
+            args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )

@@ -526,8 +549,16 @@ def main():

     # Initialize the optimizer
     optimizer = optimizer_class(
-        unet.parameters(),  # only optimize unet
-        lr=args.learning_rate,
+        [
+            {
+                'params': unet.parameters(),
+                'lr': args.learning_rate_unet,
+            },
+            {
+                'params': text_encoder.parameters(),
+                'lr': args.learning_rate_text,
+            }
+        ],
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
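
Note: passing a list of dicts to the optimizer uses PyTorch's per-parameter-group options: each dict supplies its own 'params' and may override defaults such as 'lr', while keyword arguments given once (betas, weight_decay, eps) apply to every group. A self-contained sketch with stand-in modules and illustrative values:

    import torch

    unet = torch.nn.Linear(8, 8)           # stand-ins for the real models
    text_encoder = torch.nn.Linear(8, 8)

    optimizer = torch.optim.AdamW(
        [
            {"params": unet.parameters(), "lr": 3e-6},          # group 0
            {"params": text_encoder.parameters(), "lr": 3e-6},  # group 1
        ],
        betas=(0.9, 0.999),    # shared by both groups
        weight_decay=1e-2,
        eps=1e-8,
    )

    print([g["lr"] for g in optimizer.param_groups])  # [3e-06, 3e-06]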
@@ -592,8 +623,10 @@ def main():
        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]

        if len(missing_data) != 0:
-            batched_data = [missing_data[i:i+args.sample_batch_size]
-                            for i in range(0, len(missing_data), args.sample_batch_size)]
+            batched_data = [
+                missing_data[i:i+args.sample_batch_size]
+                for i in range(0, len(missing_data), args.sample_batch_size)
+            ]

            scheduler = EulerAScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
@@ -610,9 +643,9 @@ def main():

        with torch.inference_mode():
            for batch in batched_data:
-                image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
-                nprompt = [p.nprompt for p in batch]
+                image_name = [item.class_image_path for item in batch]
+                prompt = [item.prompt.format(args.class_identifier) for item in batch]
+                nprompt = [item.nprompt for item in batch]

                images = pipeline(
                    prompt=prompt,
@@ -655,16 +688,14 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )

-    unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

     # Move text_encoder and vae to device
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

     # Keep text_encoder and vae in eval mode as we don't train these
-    text_encoder.eval()
     vae.eval()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -736,12 +767,13 @@ def main():
            local_progress_bar.reset()

        unet.train()
+        text_encoder.train()
        train_loss = 0.0

        sample_checkpoint = False

        for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(unet):
+            with accelerator.accumulate(itertools.chain(unet, text_encoder)):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                latents = latents * 0.18215
@@ -782,7 +814,8 @@ def main():

                accelerator.backward(loss)
                if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                    accelerator.clip_grad_norm_(itertools.chain(
+                        unet.parameters(), text_encoder.parameters()), args.max_grad_norm)
                optimizer.step()
                if not accelerator.optimizer_step_was_skipped:
                    lr_scheduler.step()
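
Note: clip_grad_norm_ accepts any iterable of parameters, so chaining the two .parameters() generators clips against a single global norm computed across the UNet and the text encoder, rather than clipping each model independently (accelerator.clip_grad_norm_ is accelerate's wrapper around the same torch utility). A standalone sketch with stand-in modules:

    import itertools
    import torch

    unet = torch.nn.Linear(4, 4)           # stand-ins for the real models
    text_encoder = torch.nn.Linear(4, 4)

    out = unet(torch.randn(1, 4)).sum() + text_encoder(torch.randn(1, 4)).sum()
    out.backward()

    # one global norm over every trainable parameter of both modules
    torch.nn.utils.clip_grad_norm_(
        itertools.chain(unet.parameters(), text_encoder.parameters()),
        max_norm=1.0,
    )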
@@ -804,7 +837,11 @@ def main():
                if global_step % args.sample_frequency == 0:
                    sample_checkpoint = True

-                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {
+                    "train/loss": loss,
+                    "lr/unet": lr_scheduler.get_last_lr()[0],
+                    "lr/text": lr_scheduler.get_last_lr()[1]
+                }
                if args.use_ema:
                    logs["ema_decay"] = ema_unet.decay

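
Note: with two parameter groups, lr_scheduler.get_last_lr() returns one entry per group, in the order the groups were passed to the optimizer, hence index 0 for the UNet rate and index 1 for the text-encoder rate. A small sketch (assuming a recent PyTorch):

    import torch

    opt = torch.optim.AdamW([
        {"params": [torch.nn.Parameter(torch.zeros(1))], "lr": 3e-6},  # group 0
        {"params": [torch.nn.Parameter(torch.zeros(1))], "lr": 3e-6},  # group 1
    ])
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 1.0)

    print(sched.get_last_lr())  # [3e-06, 3e-06] -- one value per group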
@@ -820,6 +857,7 @@ def main():
        accelerator.wait_for_everyone()

        unet.eval()
+        text_encoder.eval()
        val_loss = 0.0

        for step, batch in enumerate(val_dataloader):
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 06ff45b..413abe3 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -125,7 +125,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2400,
+        default=4700,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -142,13 +142,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -578,6 +578,7 @@ def main():

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()

     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)
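
Note: the two models expose the same feature under different method names: diffusers' UNet2DConditionModel uses enable_gradient_checkpointing(), while transformers' CLIPTextModel uses gradient_checkpointing_enable(). Both recompute activations during the backward pass to cut memory at the cost of extra compute. A sketch (the model path is a placeholder for args.pretrained_model_name_or_path):

    from diffusers import UNet2DConditionModel
    from transformers import CLIPTextModel

    model_path = "path/to/pretrained/model"  # placeholder
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")

    unet.enable_gradient_checkpointing()           # diffusers spelling
    text_encoder.gradient_checkpointing_enable()   # transformers spelling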
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index c1e3340..259ac44 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import List, Union

 import torch
