-rw-r--r--  dreambooth_plus.py                                   | 34
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  8
-rw-r--r--  textual_inversion.py                                 | 33
3 files changed, 49 insertions(+), 26 deletions(-)
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index b5ec2fc..eeee424 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -16,7 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2300,
+        default=1300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -317,6 +317,13 @@ def parse_args():
     return args


+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -503,6 +510,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -706,12 +715,21 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=num_update_steps_per_epoch,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
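
For context, a minimal sketch of the scheduler selection this commit adds to dreambooth_plus.py, assuming diffusers' get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles) signature. Unlike get_scheduler, that helper takes no scheduler-name argument, so the sketch passes only the optimizer; the build_lr_scheduler wrapper name is illustrative and not part of the commit.

```python
# Illustrative sketch only, not the committed code. Assumes diffusers.optimization exposes
# get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps,
# num_training_steps, num_cycles), which takes the optimizer as its first argument.
from diffusers.optimization import (
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_scheduler,
)


def build_lr_scheduler(args, optimizer, num_update_steps_per_epoch):
    # Warmup and total steps are scaled by gradient accumulation, as in the hunk above.
    warmup_steps = args.lr_warmup_steps * args.gradient_accumulation_steps
    training_steps = args.max_train_steps * args.gradient_accumulation_steps

    if args.lr_scheduler == "cosine_with_restarts":
        return get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=training_steps,
            num_cycles=num_update_steps_per_epoch,  # cycle count chosen by the commit
        )

    return get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=training_steps,
    )
```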
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3e41f86..2656b28 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -4,7 +4,6 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
-import torch.optim as optim
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
@@ -59,9 +58,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
-    def get_text_embeddings(self, text_input_ids):
-        return self.text_encoder(text_input_ids)[0]
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -199,7 +195,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device))
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -211,7 +207,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device))
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
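
The pipeline change above drops the one-line get_text_embeddings wrapper and calls the text encoder directly. Below is a rough sketch of that embedding path, assuming the CLIP tokenizer/text-encoder pair used by Stable Diffusion pipelines; encode_prompts and its arguments are illustrative names, not part of the pipeline API.

```python
# Illustrative sketch of the prompt-encoding path after this change; not the pipeline's code.
import torch


def encode_prompts(tokenizer, text_encoder, prompt, negative_prompt, device):
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # The encoder is called directly; [0] selects the last hidden state.
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    uncond_input = tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=text_input.input_ids.shape[-1],
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    # For classifier-free guidance, unconditional and conditional embeddings are
    # concatenated so both passes run in a single forward call.
    return torch.cat([uncond_embeddings, text_embeddings])
```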
diff --git a/textual_inversion.py b/textual_inversion.py
index 6627f1f..2109d13 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -16,7 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -303,10 +303,10 @@ def freeze_params(params):
     param.requires_grad = False
 
 
-def save_resume_file(basepath, args, extra={}):
+def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
-    with open(f"{basepath}/resume.json", "w") as f:
+    with open(basepath.joinpath("args.json"), "w") as f:
         json.dump(info, f, indent=4)
 
 
@@ -660,12 +660,21 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=num_update_steps_per_epoch,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
@@ -827,7 +836,7 @@ def main():
                     global_progress_bar.clear()
 
                     checkpointer.checkpoint(global_step + global_step_offset, "training")
-                    save_resume_file(basepath, args, {
+                    save_args(basepath, args, {
                         "global_step": global_step + global_step_offset,
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                     })
@@ -901,7 +910,7 @@ def main():
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_resume_file(basepath, args, {
+            save_args(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
@@ -911,7 +920,7 @@ def main():
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_resume_file(basepath, args, {
+            save_args(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
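
Across both training scripts this commit converges on a single save_args helper: dreambooth_plus.py gains it and calls it once at startup, while textual_inversion.py renames save_resume_file to save_args and switches its output from resume.json to args.json. A condensed sketch of the helper and its call sites, assembled from the hunks above:

```python
# Condensed from the hunks above; basepath is a pathlib.Path to the output directory.
import json
from pathlib import Path


def save_args(basepath: Path, args, extra={}):
    # Record the full argparse namespace, plus any extra resume state, as args.json.
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


# In dreambooth_plus.py the arguments are saved once at startup:
#     save_args(basepath, args)
# In textual_inversion.py the resume state is added at checkpoints, on completion, and on interrupt:
#     save_args(basepath, args, {
#         "global_step": global_step + global_step_offset,
#         "resume_checkpoint": f"{basepath}/checkpoints/last.bin",
#     })
```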