diff options
| author | Volpeon <git@volpeon.ink> | 2022-11-03 17:56:08 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-11-03 17:56:08 +0100 |
| commit | 036e33fbde6bad7c48bb6f6b3d695b7908535c64 (patch) | |
| tree | 6b0df3deed03ccc8722763b51a9d7c2019be8e10 /dreambooth.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.tar.gz textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.tar.bz2 textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.zip | |
Update
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 14 |
1 file changed, 11 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py index c0caf03..8c4bf50 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -26,6 +26,7 @@ from slugify import slugify | |||
| 26 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler | 26 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler |
| 27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 28 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
| 29 | from training.optimization import get_one_cycle_schedule | ||
| 29 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
| 30 | 31 | ||
| 31 | logger = get_logger(__name__) | 32 | logger = get_logger(__name__) |
| @@ -178,10 +179,10 @@ def parse_args(): | |||
| 178 | parser.add_argument( | 179 | parser.add_argument( |
| 179 | "--lr_scheduler", | 180 | "--lr_scheduler", |
| 180 | type=str, | 181 | type=str, |
| 181 | default="cosine_with_restarts", | 182 | default="one_cycle", |
| 182 | help=( | 183 | help=( |
| 183 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 184 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
| 184 | ' "constant", "constant_with_warmup"]' | 185 | ' "constant", "constant_with_warmup", "one_cycle"]' |
| 185 | ), | 186 | ), |
| 186 | ) | 187 | ) |
| 187 | parser.add_argument( | 188 | parser.add_argument( |
| @@ -585,6 +586,8 @@ def main(): | |||
| 585 | device=accelerator.device | 586 | device=accelerator.device |
| 586 | ) | 587 | ) |
| 587 | 588 | ||
| 589 | unet.set_use_memory_efficient_attention_xformers(True) | ||
| 590 | |||
| 588 | if args.gradient_checkpointing: | 591 | if args.gradient_checkpointing: |
| 589 | unet.enable_gradient_checkpointing() | 592 | unet.enable_gradient_checkpointing() |
| 590 | text_encoder.gradient_checkpointing_enable() | 593 | text_encoder.gradient_checkpointing_enable() |
| @@ -784,7 +787,12 @@ def main(): | |||
| 784 | overrode_max_train_steps = True | 787 | overrode_max_train_steps = True |
| 785 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 788 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
| 786 | 789 | ||
| 787 | if args.lr_scheduler == "cosine_with_restarts": | 790 | if args.lr_scheduler == "one_cycle": |
| 791 | lr_scheduler = get_one_cycle_schedule( | ||
| 792 | optimizer=optimizer, | ||
| 793 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
| 794 | ) | ||
| 795 | elif args.lr_scheduler == "cosine_with_restarts": | ||
| 788 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 796 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
| 789 | optimizer=optimizer, | 797 | optimizer=optimizer, |
| 790 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 798 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
