-rw-r--r--   dreambooth.py                                        | 14
-rw-r--r--   environment.yaml                                     |  5
-rw-r--r--   infer.py                                             |  2
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 36
-rw-r--r--   textual_inversion.py                                 | 14
-rw-r--r--   training/optimization.py                             | 42
6 files changed, 102 insertions(+), 11 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index c0caf03..8c4bf50 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -26,6 +26,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -178,10 +179,10 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine_with_restarts",
+        default="one_cycle",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
         ),
     )
     parser.add_argument(
@@ -585,6 +586,8 @@ def main():
         device=accelerator.device
     )
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
@@ -784,7 +787,12 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    if args.lr_scheduler == "cosine_with_restarts":
+    if args.lr_scheduler == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+    elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
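Note that, unlike the cosine_with_restarts branch, get_one_cycle_schedule takes no num_warmup_steps argument: its warmup length comes from the mid_point fraction inside training/optimization.py, and only the total length is scaled by gradient_accumulation_steps, presumably because the scheduler is stepped on every accumulation micro-step, as the existing branch already does for its warmup steps. A small arithmetic sketch with hypothetical numbers, for illustration only:

# Hypothetical values, not from the repo.
max_train_steps = 3000                 # optimizer updates requested on the CLI
gradient_accumulation_steps = 4        # micro-steps per optimizer update
num_training_steps = max_train_steps * gradient_accumulation_steps  # 12000 scheduler steps
warmup_steps = int(num_training_steps * 0.42)                       # 5040 warmup steps at the default mid_point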
diff --git a/environment.yaml b/environment.yaml
index de35645..7aa5312 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,6 +1,7 @@
 name: ldd
 channels:
   - pytorch
+  - xformers/label/dev
   - defaults
 dependencies:
   - cudatoolkit=11.3
@@ -10,13 +11,14 @@ dependencies:
   - pytorch=1.12.1
   - torchvision=0.13.1
   - pandas=1.4.3
+  - xformers=0.0.14.dev315
   - pip:
     - -e .
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
     - -e git+https://github.com/openai/CLIP.git@main#egg=clip
     - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
     - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion
-    - -e git+https://github.com/ShivamShrirao/diffusers#egg=diffusers
+    - -e git+https://github.com/huggingface/diffusers#egg=diffusers
     - accelerate==0.12.0
     - albumentations==1.1.0
     - bitsandbytes==0.34.0
@@ -34,4 +36,3 @@ dependencies:
     - torchmetrics==0.9.3
     - transformers==4.23.1
     - triton==2.0.0.dev20220924
-    - xformers==0.0.13
diff --git a/infer.py b/infer.py
index ac05955..9bc9efe 100644
--- a/infer.py
+++ b/infer.py
@@ -234,7 +234,7 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
-    pipeline.aesthetic_gradient_iters = 20
+    pipeline.enable_xformers_memory_efficient_attention()
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cd5ae7e..36942f0 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -7,6 +7,7 @@ import torch
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
+from diffusers.utils import is_accelerate_available
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
@@ -61,13 +62,27 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
-
         When this option is enabled, the attention module will split the input tensor in slices, to compute attention
         in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
         Args:
             slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
@@ -88,6 +103,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and are then moved to
+        `torch.device('meta')`, getting loaded onto the GPU only when their specific submodule's `forward` is called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
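For reference, a minimal usage sketch of the new pipeline toggles (not part of this commit; `pipeline` is assumed to be an already constructed VlpnStableDiffusion instance, and generation is elided):

# Route the UNet's attention through xformers for lower VRAM usage.
pipeline.enable_xformers_memory_efficient_attention()
pipeline.to("cuda")

# Alternatively, when VRAM is very tight, offload submodules via accelerate instead
# of moving the whole pipeline to the GPU; each submodule is loaded onto the GPU
# only while its forward() runs.
# pipeline.enable_sequential_cpu_offload()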
diff --git a/textual_inversion.py b/textual_inversion.py
index 115f3aa..578c054 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,6 +25,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -162,10 +163,10 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine_with_restarts",
+        default="one_cycle",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
         ),
     )
     parser.add_argument(
@@ -535,6 +536,8 @@ def main():
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
     if args.gradient_checkpointing:
         text_encoder.gradient_checkpointing_enable()
 
@@ -693,7 +696,12 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    if args.lr_scheduler == "cosine_with_restarts":
+    if args.lr_scheduler == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+    elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
diff --git a/training/optimization.py b/training/optimization.py
new file mode 100644
index 0000000..012beed
--- /dev/null
+++ b/training/optimization.py
@@ -0,0 +1,42 @@
+import math
+from torch.optim.lr_scheduler import LambdaLR
+
+from diffusers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.42, last_epoch=-1):
+    """
+    Create a one-cycle schedule: the learning rate rises linearly from `min_lr` (as a fraction of the initial lr set
+    in the optimizer) up to the initial lr over the first `mid_point` of training, then anneals back towards zero.
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_training_steps (`int`):
+            The total number of training steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        thresh_up = int(num_training_steps * min(mid_point, 0.5))
+
+        if current_step < thresh_up:
+            return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr)
+
+        if annealing == "linear":
+            thresh_down = thresh_up * 2
+
+            if current_step < thresh_down:
+                return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr)
+
+            return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))) * min_lr
+        else:
+            progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
+
+            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
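As a sanity check on the new schedule, a throwaway sketch (not part of the commit) that prints the learning rate at a few points; the toy optimizer and step count are arbitrary:

import torch
from training.optimization import get_one_cycle_schedule

# Toy optimizer purely to inspect the schedule shape.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-4)
scheduler = get_one_cycle_schedule(optimizer, num_training_steps=1000)

for step in range(1000):
    optimizer.step()
    scheduler.step()
    if step % 100 == 0:
        print(step, scheduler.get_last_lr()[0])

# With the defaults (annealing="cos", min_lr=0.05, mid_point=0.42) the lr ramps
# linearly from 0.05 * 1e-4 up to 1e-4 over the first 420 steps, then follows a
# cosine decay towards zero over the remaining steps.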