-rw-r--r--   dreambooth.py                                        | 14
-rw-r--r--   environment.yaml                                     |  5
-rw-r--r--   infer.py                                             |  2
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 36
-rw-r--r--   textual_inversion.py                                 | 14
-rw-r--r--   training/optimization.py                             | 42
6 files changed, 102 insertions, 11 deletions
diff --git a/dreambooth.py b/dreambooth.py
index c0caf03..8c4bf50 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -26,6 +26,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -178,10 +179,10 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine_with_restarts",
+        default="one_cycle",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
         ),
     )
     parser.add_argument(
@@ -585,6 +586,8 @@ def main():
         device=accelerator.device
     )
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
@@ -784,7 +787,12 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    if args.lr_scheduler == "cosine_with_restarts":
+    if args.lr_scheduler == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+    elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
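Note: the viewer cuts the last hunk off mid-call. For orientation, a hedged sketch of the resulting scheduler dispatch follows; the wrapper function, the import locations, and the generic get_scheduler fallback are assumptions rather than lines from this commit.

# Sketch only: the fallback branch and this wrapper are assumptions, not shown in the diff.
from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup

from training.optimization import get_one_cycle_schedule


def create_lr_scheduler(args, optimizer):
    steps = args.max_train_steps * args.gradient_accumulation_steps
    warmup = args.lr_warmup_steps * args.gradient_accumulation_steps

    if args.lr_scheduler == "one_cycle":
        # new branch added by this commit; note it does not take --lr_warmup_steps,
        # since the schedule has its own built-in ramp-up phase
        return get_one_cycle_schedule(optimizer=optimizer, num_training_steps=steps)
    if args.lr_scheduler == "cosine_with_restarts":
        return get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer, num_warmup_steps=warmup, num_training_steps=steps
        )
    # remaining choices ("linear", "cosine", "polynomial", "constant", ...) go through the generic helper
    return get_scheduler(args.lr_scheduler, optimizer=optimizer, num_warmup_steps=warmup, num_training_steps=steps)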
diff --git a/environment.yaml b/environment.yaml
index de35645..7aa5312 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,6 +1,7 @@
 name: ldd
 channels:
   - pytorch
+  - xformers/label/dev
   - defaults
 dependencies:
   - cudatoolkit=11.3
@@ -10,13 +11,14 @@ dependencies:
   - pytorch=1.12.1
   - torchvision=0.13.1
   - pandas=1.4.3
+  - xformers=0.0.14.dev315
   - pip:
     - -e .
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
     - -e git+https://github.com/openai/CLIP.git@main#egg=clip
    - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
     - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion
-    - -e git+https://github.com/ShivamShrirao/diffusers#egg=diffusers
+    - -e git+https://github.com/huggingface/diffusers#egg=diffusers
     - accelerate==0.12.0
     - albumentations==1.1.0
     - bitsandbytes==0.34.0
@@ -34,4 +36,3 @@ dependencies:
     - torchmetrics==0.9.3
     - transformers==4.23.1
     - triton==2.0.0.dev20220924
-    - xformers==0.0.13
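The xformers dependency moves from a pinned pip wheel (0.0.13) to a dev build pulled from the xformers/label/dev conda channel, and the editable diffusers install now tracks huggingface/diffusers instead of the ShivamShrirao fork. A minimal post-install sanity check (a sketch, not part of the diff):

# Sketch: confirm the conda-provided xformers build imports and that a GPU is visible.
import torch
import xformers
import xformers.ops  # the memory-efficient attention kernels used by diffusers live here

print(xformers.__version__)       # expected: 0.0.14.dev315, per environment.yaml
print(torch.cuda.is_available())  # the fused attention kernels need a CUDA device at runtime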
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -234,7 +234,7 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
-    pipeline.aesthetic_gradient_iters = 20
+    pipeline.enable_xformers_memory_efficient_attention()
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cd5ae7e..36942f0 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -7,6 +7,7 @@ import torch
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
+from diffusers.utils import is_accelerate_available
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
@@ -61,13 +62,27 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
-
         When this option is enabled, the attention module will split the input tensor in slices, to compute attention
         in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
         Args:
             slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
@@ -88,6 +103,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
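The pipeline now mirrors the memory-saving toggles of the upstream StableDiffusionPipeline: xformers attention, attention slicing, and sequential CPU offload. A usage sketch, assuming `pipeline` is a VlpnStableDiffusion assembled the way infer.py's create_pipeline does it:

# Usage sketch; `pipeline` is assumed to be an already-constructed VlpnStableDiffusion.
pipeline.enable_xformers_memory_efficient_attention()  # route UNet attention through the xformers kernels

# Pick one placement strategy:
pipeline.to("cuda")                         # keep every submodule resident on the GPU (fastest)
# pipeline.enable_sequential_cpu_offload()  # or park submodules on the CPU until their forward() runs (needs accelerate)

As the new docstring notes, when xformers attention and attention slicing are both enabled, the xformers path takes precedence.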
diff --git a/textual_inversion.py b/textual_inversion.py
index 115f3aa..578c054 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,6 +25,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -162,10 +163,10 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine_with_restarts",
+        default="one_cycle",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
     )
     parser.add_argument(
@@ -535,6 +536,8 @@ def main():
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
     if args.gradient_checkpointing:
         text_encoder.gradient_checkpointing_enable()
 
@@ -693,7 +696,12 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    if args.lr_scheduler == "cosine_with_restarts":
+    if args.lr_scheduler == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+    elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
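textual_inversion.py receives the same treatment as dreambooth.py: the one_cycle default, the xformers attention switch, and the new scheduler branch. A hedged launch sketch; apart from --lr_scheduler, every argument here is a placeholder assumption about the script's CLI, not taken from the diff:

# Launch sketch; the step count is an arbitrary placeholder.
import subprocess

subprocess.run([
    "accelerate", "launch", "textual_inversion.py",
    "--lr_scheduler", "one_cycle",   # now the default, spelled out here for clarity
    "--max_train_steps", "3000",
], check=True)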
diff --git a/training/optimization.py b/training/optimization.py
new file mode 100644
index 0000000..012beed
--- /dev/null
+++ b/training/optimization.py
@@ -0,0 +1,42 @@
+import math
+from torch.optim.lr_scheduler import LambdaLR
+
+from diffusers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.42, last_epoch=-1):
+    """
+    Create a one-cycle schedule: the learning rate rises linearly from `min_lr` times the initial lr to the initial lr
+    over the first `mid_point` fraction of training, then anneals back down to 0 (cosine or linear, per `annealing`).
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_training_steps (`int`):
+            The total number of training steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        thresh_up = int(num_training_steps * min(mid_point, 0.5))
+
+        if current_step < thresh_up:
+            return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr)
+
+        if annealing == "linear":
+            thresh_down = thresh_up * 2
+
+            if current_step < thresh_down:
+                return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr)
+
+            return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))) * min_lr
+        else:
+            progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
+
+            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
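A quick way to see the shape of the new schedule (a sketch; the dummy model, base lr, and step count are arbitrary). With the defaults, the multiplier ramps linearly from 0.05 up to 1.0 at 42% of training, then follows a cosine back down to 0:

# Sketch: inspect the one-cycle multiplier on a throwaway optimizer.
import torch
from training.optimization import get_one_cycle_schedule

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_one_cycle_schedule(optimizer, num_training_steps=100)

for step in range(100):
    optimizer.step()
    scheduler.step()
    if step in (0, 20, 41, 70, 99):
        # lr / base_lr prints roughly 0.07 -> 0.53 -> 1.0 -> 0.5 -> 0.0 with the default "cos" annealing
        print(step, round(scheduler.get_last_lr()[0] / 1e-4, 3))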