diff options
author | Volpeon <git@volpeon.ink> | 2022-11-03 17:56:08 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-03 17:56:08 +0100 |
commit | 036e33fbde6bad7c48bb6f6b3d695b7908535c64 (patch) | |
tree | 6b0df3deed03ccc8722763b51a9d7c2019be8e10 /dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.tar.gz textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.tar.bz2 textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.zip |
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 14 |
1 file changed, 11 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py index c0caf03..8c4bf50 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -26,6 +26,7 @@ from slugify import slugify | |||
26 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler | 26 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler |
27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
28 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
29 | from training.optimization import get_one_cycle_schedule | ||
29 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
30 | 31 | ||
31 | logger = get_logger(__name__) | 32 | logger = get_logger(__name__) |
@@ -178,10 +179,10 @@ def parse_args(): | |||
178 | parser.add_argument( | 179 | parser.add_argument( |
179 | "--lr_scheduler", | 180 | "--lr_scheduler", |
180 | type=str, | 181 | type=str, |
181 | default="cosine_with_restarts", | 182 | default="one_cycle", |
182 | help=( | 183 | help=( |
183 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 184 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
184 | ' "constant", "constant_with_warmup"]' | 185 | ' "constant", "constant_with_warmup", "one_cycle"]' |
185 | ), | 186 | ), |
186 | ) | 187 | ) |
187 | parser.add_argument( | 188 | parser.add_argument( |
@@ -585,6 +586,8 @@ def main(): | |||
585 | device=accelerator.device | 586 | device=accelerator.device |
586 | ) | 587 | ) |
587 | 588 | ||
589 | unet.set_use_memory_efficient_attention_xformers(True) | ||
590 | |||
588 | if args.gradient_checkpointing: | 591 | if args.gradient_checkpointing: |
589 | unet.enable_gradient_checkpointing() | 592 | unet.enable_gradient_checkpointing() |
590 | text_encoder.gradient_checkpointing_enable() | 593 | text_encoder.gradient_checkpointing_enable() |
@@ -784,7 +787,12 @@ def main(): | |||
784 | overrode_max_train_steps = True | 787 | overrode_max_train_steps = True |
785 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 788 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
786 | 789 | ||
787 | if args.lr_scheduler == "cosine_with_restarts": | 790 | if args.lr_scheduler == "one_cycle": |
791 | lr_scheduler = get_one_cycle_schedule( | ||
792 | optimizer=optimizer, | ||
793 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
794 | ) | ||
795 | elif args.lr_scheduler == "cosine_with_restarts": | ||
788 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 796 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
789 | optimizer=optimizer, | 797 | optimizer=optimizer, |
790 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 798 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |