author     Volpeon <git@volpeon.ink>    2022-10-15 18:42:27 +0200
committer  Volpeon <git@volpeon.ink>    2022-10-15 18:42:27 +0200
commit     fcbc11be99c011ab1003451ef72c95ca587902d8 (patch)
tree       8a8416e2777874addd05fa2f59896a31f044f1fc
parent     Removed aesthetic gradients; training improvements (diff)
Update
-rw-r--r--  dreambooth_plus.py                                    | 34
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py   |  8
-rw-r--r--  textual_inversion.py                                  | 33
3 files changed, 49 insertions(+), 26 deletions(-)
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index b5ec2fc..eeee424 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -16,7 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2300,
+        default=1300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -317,6 +317,13 @@ def parse_args():
     return args
 
 
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -503,6 +510,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -706,12 +715,21 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=num_update_steps_per_epoch,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
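For context, here is a minimal standalone sketch of the scheduler selection added above. All values are hypothetical stand-ins for the script's CLI arguments, and the call assumes the `diffusers.optimization` signature `get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1)`, so it passes keyword arguments only rather than the extra leading `args.lr_scheduler` positional argument used in the hunk above.

```python
import torch
from diffusers.optimization import (
    get_scheduler,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)

# Hypothetical stand-ins for the parsed CLI arguments.
lr_scheduler_name = "cosine_with_restarts"
lr_warmup_steps = 100
max_train_steps = 1300
gradient_accumulation_steps = 1
num_update_steps_per_epoch = 100

# Any optimizer works for the sketch; the training scripts use AdamW.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

if lr_scheduler_name == "cosine_with_restarts":
    # Cosine decay with `num_cycles` hard restarts spread over the whole run.
    lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
        num_cycles=num_update_steps_per_epoch,
    )
else:
    # Every other scheduler name goes through the generic factory as before.
    lr_scheduler = get_scheduler(
        lr_scheduler_name,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

for _ in range(max_train_steps):
    optimizer.step()      # a real loop computes a loss and calls backward() first
    lr_scheduler.step()
```

With `num_cycles` set to the number of update steps per epoch, the learning rate warms up once and then restarts that many times over the full training run instead of decaying in a single cosine arc.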
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3e41f86..2656b28 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -4,7 +4,6 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
-import torch.optim as optim
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
@@ -59,9 +58,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
-    def get_text_embeddings(self, text_input_ids):
-        return self.text_encoder(text_input_ids)[0]
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -199,7 +195,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device))
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -211,7 +207,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device))
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
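The pipeline change above drops the one-line `get_text_embeddings` helper and calls the text encoder directly, alongside removing an unused `torch.optim` import. As a rough sketch of what those calls return, using a stock CLIP text encoder (the checkpoint name and prompt are only illustrative; the pipeline uses whatever tokenizer and text encoder it was constructed with):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Illustrative checkpoint; Stable Diffusion 1.x ships a CLIP ViT-L/14 text encoder.
model_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(model_name)
text_encoder = CLIPTextModel.from_pretrained(model_name)

text_input = tokenizer(
    "a photo of an astronaut riding a horse",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    # Index [0] is last_hidden_state: per-token embeddings of shape
    # (batch, sequence_length, hidden_size) that the UNet is conditioned on.
    text_embeddings = text_encoder(text_input.input_ids)[0]

print(text_embeddings.shape)  # e.g. torch.Size([1, 77, 768])
```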
diff --git a/textual_inversion.py b/textual_inversion.py
index 6627f1f..2109d13 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -16,7 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -303,10 +303,10 @@ def freeze_params(params):
         param.requires_grad = False
 
 
-def save_resume_file(basepath, args, extra={}):
+def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
-    with open(f"{basepath}/resume.json", "w") as f:
+    with open(basepath.joinpath("args.json"), "w") as f:
         json.dump(info, f, indent=4)
 
 
@@ -660,12 +660,21 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=num_update_steps_per_epoch,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
@@ -827,7 +836,7 @@ def main():
                     global_progress_bar.clear()
 
                     checkpointer.checkpoint(global_step + global_step_offset, "training")
-                    save_resume_file(basepath, args, {
+                    save_args(basepath, args, {
                         "global_step": global_step + global_step_offset,
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                     })
@@ -901,7 +910,7 @@ def main():
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_resume_file(basepath, args, {
+            save_args(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
@@ -911,7 +920,7 @@ def main():
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_resume_file(basepath, args, {
+            save_args(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
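For reference, a small sketch of the `save_args` helper introduced above (renamed from `save_resume_file`, and now writing `args.json` instead of `resume.json`), showing how the same file records the run configuration at startup and picks up resume metadata at checkpoint, interrupt, or end of training. The output directory and argument values are made up for illustration.

```python
import json
from argparse import Namespace
from pathlib import Path


def save_args(basepath: Path, args, extra={}):
    # Same helper as in the diff above.
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


# Hypothetical output directory and arguments.
basepath = Path("output/example-run")
basepath.mkdir(parents=True, exist_ok=True)
args = Namespace(lr_scheduler="cosine_with_restarts", max_train_steps=1300)

# At startup: record the plain configuration.
save_args(basepath, args)

# At a checkpoint (or on interrupt / at the end): add resume metadata on top.
save_args(basepath, args, {
    "global_step": 500,
    "resume_checkpoint": f"{basepath}/checkpoints/last.bin",
})

print(basepath.joinpath("args.json").read_text())
# {
#     "args": {
#         "lr_scheduler": "cosine_with_restarts",
#         "max_train_steps": 1300,
#         "global_step": 500,
#         "resume_checkpoint": "output/example-run/checkpoints/last.bin"
#     }
# }
```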