summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-19 12:19:23 +0200
committerVolpeon <git@volpeon.ink>2022-10-19 12:19:23 +0200
commitb4a00845721fbc95819ad888dfd7c24013bbf4d0 (patch)
treedf5888d0a52077d7fb1035939fb2b2e8547a0655 /dreambooth.py
parentAdapted other scripts for new prompt processing (diff)
downloadtextual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.tar.gz
textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.tar.bz2
textual-inversion-diff-b4a00845721fbc95819ad888dfd7c24013bbf4d0.zip
Updated Dreambooth training
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py104
1 files changed, 71 insertions, 33 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 9786e0f..d1cf535 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,4 +1,5 @@
1import argparse 1import argparse
2import itertools
2import math 3import math
3import os 4import os
4import datetime 5import datetime
@@ -113,7 +114,7 @@ def parse_args():
113 parser.add_argument( 114 parser.add_argument(
114 "--max_train_steps", 115 "--max_train_steps",
115 type=int, 116 type=int,
116 default=1200, 117 default=3600,
117 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 118 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
118 ) 119 )
119 parser.add_argument( 120 parser.add_argument(
@@ -128,9 +129,15 @@ def parse_args():
128 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 129 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
129 ) 130 )
130 parser.add_argument( 131 parser.add_argument(
131 "--learning_rate", 132 "--learning_rate_unet",
132 type=float, 133 type=float,
133 default=5e-5, 134 default=3e-6,
135 help="Initial learning rate (after the potential warmup period) to use.",
136 )
137 parser.add_argument(
138 "--learning_rate_text",
139 type=float,
140 default=3e-6,
134 help="Initial learning rate (after the potential warmup period) to use.", 141 help="Initial learning rate (after the potential warmup period) to use.",
135 ) 142 )
136 parser.add_argument( 143 parser.add_argument(
@@ -358,12 +365,14 @@ class Checkpointer:
358 def save_model(self): 365 def save_model(self):
359 print("Saving model...") 366 print("Saving model...")
360 367
361 unwrapped = self.accelerator.unwrap_model( 368 unwrapped_unet = self.accelerator.unwrap_model(
362 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) 369 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
370 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
371
363 pipeline = VlpnStableDiffusion( 372 pipeline = VlpnStableDiffusion(
364 text_encoder=self.text_encoder, 373 text_encoder=unwrapped_text_encoder,
365 vae=self.vae, 374 vae=self.vae,
366 unet=unwrapped, 375 unet=unwrapped_unet,
367 tokenizer=self.tokenizer, 376 tokenizer=self.tokenizer,
368 scheduler=PNDMScheduler( 377 scheduler=PNDMScheduler(
369 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True 378 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
@@ -371,7 +380,8 @@ class Checkpointer:
371 ) 380 )
372 pipeline.save_pretrained(self.output_dir.joinpath("model")) 381 pipeline.save_pretrained(self.output_dir.joinpath("model"))
373 382
374 del unwrapped 383 del unwrapped_unet
384 del unwrapped_text_encoder
375 del pipeline 385 del pipeline
376 386
377 if torch.cuda.is_available(): 387 if torch.cuda.is_available():
@@ -381,16 +391,18 @@ class Checkpointer:
381 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): 391 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
382 samples_path = Path(self.output_dir).joinpath("samples") 392 samples_path = Path(self.output_dir).joinpath("samples")
383 393
384 unwrapped = self.accelerator.unwrap_model( 394 unwrapped_unet = self.accelerator.unwrap_model(
385 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) 395 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
396 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
397
386 scheduler = EulerAScheduler( 398 scheduler = EulerAScheduler(
387 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 399 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
388 ) 400 )
389 401
390 pipeline = VlpnStableDiffusion( 402 pipeline = VlpnStableDiffusion(
391 text_encoder=self.text_encoder, 403 text_encoder=unwrapped_text_encoder,
392 vae=self.vae, 404 vae=self.vae,
393 unet=unwrapped, 405 unet=unwrapped_unet,
394 tokenizer=self.tokenizer, 406 tokenizer=self.tokenizer,
395 scheduler=scheduler, 407 scheduler=scheduler,
396 ).to(self.accelerator.device) 408 ).to(self.accelerator.device)
@@ -416,9 +428,16 @@ class Checkpointer:
416 428
417 for i in range(self.sample_batches): 429 for i in range(self.sample_batches):
418 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 430 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
419 prompt = [prompt.format(self.instance_identifier) 431 prompt = [
420 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] 432 prompt.format(self.instance_identifier)
421 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] 433 for batch in batches
434 for prompt in batch["prompts"]
435 ][:self.sample_batch_size]
436 nprompt = [
437 prompt
438 for batch in batches
439 for prompt in batch["nprompts"]
440 ][:self.sample_batch_size]
422 441
423 samples = pipeline( 442 samples = pipeline(
424 prompt=prompt, 443 prompt=prompt,
@@ -443,7 +462,8 @@ class Checkpointer:
443 del all_samples 462 del all_samples
444 del image_grid 463 del image_grid
445 464
446 del unwrapped 465 del unwrapped_unet
466 del unwrapped_text_encoder
447 del scheduler 467 del scheduler
448 del pipeline 468 del pipeline
449 del generator 469 del generator
@@ -482,8 +502,7 @@ def main():
482 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') 502 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
483 503
484 # Load models and create wrapper for stable diffusion 504 # Load models and create wrapper for stable diffusion
485 text_encoder = CLIPTextModel.from_pretrained( 505 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
486 args.pretrained_model_name_or_path, subfolder='text_encoder')
487 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') 506 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
488 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') 507 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
489 508
@@ -499,17 +518,21 @@ def main():
499 518
500 if args.gradient_checkpointing: 519 if args.gradient_checkpointing:
501 unet.enable_gradient_checkpointing() 520 unet.enable_gradient_checkpointing()
521 text_encoder.gradient_checkpointing_enable()
502 522
503 # slice_size = unet.config.attention_head_dim // 2 523 # slice_size = unet.config.attention_head_dim // 2
504 # unet.set_attention_slice(slice_size) 524 # unet.set_attention_slice(slice_size)
505 525
506 # Freeze text_encoder and vae 526 # Freeze text_encoder and vae
507 freeze_params(vae.parameters()) 527 freeze_params(vae.parameters())
508 freeze_params(text_encoder.parameters())
509 528
510 if args.scale_lr: 529 if args.scale_lr:
511 args.learning_rate = ( 530 args.learning_rate_unet = (
512 args.learning_rate * args.gradient_accumulation_steps * 531 args.learning_rate_unet * args.gradient_accumulation_steps *
532 args.train_batch_size * accelerator.num_processes
533 )
534 args.learning_rate_text = (
535 args.learning_rate_text * args.gradient_accumulation_steps *
513 args.train_batch_size * accelerator.num_processes 536 args.train_batch_size * accelerator.num_processes
514 ) 537 )
515 538
@@ -526,8 +549,16 @@ def main():
526 549
527 # Initialize the optimizer 550 # Initialize the optimizer
528 optimizer = optimizer_class( 551 optimizer = optimizer_class(
529 unet.parameters(), # only optimize unet 552 [
530 lr=args.learning_rate, 553 {
554 'params': unet.parameters(),
555 'lr': args.learning_rate_unet,
556 },
557 {
558 'params': text_encoder.parameters(),
559 'lr': args.learning_rate_text,
560 }
561 ],
531 betas=(args.adam_beta1, args.adam_beta2), 562 betas=(args.adam_beta1, args.adam_beta2),
532 weight_decay=args.adam_weight_decay, 563 weight_decay=args.adam_weight_decay,
533 eps=args.adam_epsilon, 564 eps=args.adam_epsilon,
@@ -592,8 +623,10 @@ def main():
592 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] 623 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
593 624
594 if len(missing_data) != 0: 625 if len(missing_data) != 0:
595 batched_data = [missing_data[i:i+args.sample_batch_size] 626 batched_data = [
596 for i in range(0, len(missing_data), args.sample_batch_size)] 627 missing_data[i:i+args.sample_batch_size]
628 for i in range(0, len(missing_data), args.sample_batch_size)
629 ]
597 630
598 scheduler = EulerAScheduler( 631 scheduler = EulerAScheduler(
599 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 632 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
@@ -610,9 +643,9 @@ def main():
610 643
611 with torch.inference_mode(): 644 with torch.inference_mode():
612 for batch in batched_data: 645 for batch in batched_data:
613 image_name = [p.class_image_path for p in batch] 646 image_name = [item.class_image_path for item in batch]
614 prompt = [p.prompt.format(args.class_identifier) for p in batch] 647 prompt = [item.prompt.format(args.class_identifier) for item in batch]
615 nprompt = [p.nprompt for p in batch] 648 nprompt = [item.nprompt for item in batch]
616 649
617 images = pipeline( 650 images = pipeline(
618 prompt=prompt, 651 prompt=prompt,
@@ -655,16 +688,14 @@ def main():
655 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 688 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
656 ) 689 )
657 690
658 unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 691 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
659 unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 692 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
660 ) 693 )
661 694
662 # Move text_encoder and vae to device 695 # Move text_encoder and vae to device
663 text_encoder.to(accelerator.device, dtype=weight_dtype)
664 vae.to(accelerator.device, dtype=weight_dtype) 696 vae.to(accelerator.device, dtype=weight_dtype)
665 697
666 # Keep text_encoder and vae in eval mode as we don't train these 698 # Keep text_encoder and vae in eval mode as we don't train these
667 text_encoder.eval()
668 vae.eval() 699 vae.eval()
669 700
670 # We need to recalculate our total training steps as the size of the training dataloader may have changed. 701 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -736,12 +767,13 @@ def main():
736 local_progress_bar.reset() 767 local_progress_bar.reset()
737 768
738 unet.train() 769 unet.train()
770 text_encoder.train()
739 train_loss = 0.0 771 train_loss = 0.0
740 772
741 sample_checkpoint = False 773 sample_checkpoint = False
742 774
743 for step, batch in enumerate(train_dataloader): 775 for step, batch in enumerate(train_dataloader):
744 with accelerator.accumulate(unet): 776 with accelerator.accumulate(itertools.chain(unet, text_encoder)):
745 # Convert images to latent space 777 # Convert images to latent space
746 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 778 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
747 latents = latents * 0.18215 779 latents = latents * 0.18215
@@ -782,7 +814,8 @@ def main():
782 814
783 accelerator.backward(loss) 815 accelerator.backward(loss)
784 if accelerator.sync_gradients: 816 if accelerator.sync_gradients:
785 accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) 817 accelerator.clip_grad_norm_(itertools.chain(
818 unet.parameters(), text_encoder.parameters()), args.max_grad_norm)
786 optimizer.step() 819 optimizer.step()
787 if not accelerator.optimizer_step_was_skipped: 820 if not accelerator.optimizer_step_was_skipped:
788 lr_scheduler.step() 821 lr_scheduler.step()
@@ -804,7 +837,11 @@ def main():
804 if global_step % args.sample_frequency == 0: 837 if global_step % args.sample_frequency == 0:
805 sample_checkpoint = True 838 sample_checkpoint = True
806 839
807 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} 840 logs = {
841 "train/loss": loss,
842 "lr/unet": lr_scheduler.get_last_lr()[0],
843 "lr/text": lr_scheduler.get_last_lr()[1]
844 }
808 if args.use_ema: 845 if args.use_ema:
809 logs["ema_decay"] = ema_unet.decay 846 logs["ema_decay"] = ema_unet.decay
810 847
@@ -820,6 +857,7 @@ def main():
820 accelerator.wait_for_everyone() 857 accelerator.wait_for_everyone()
821 858
822 unet.eval() 859 unet.eval()
860 text_encoder.eval()
823 val_loss = 0.0 861 val_loss = 0.0
824 862
825 for step, batch in enumerate(val_dataloader): 863 for step, batch in enumerate(val_dataloader):