diff options
author | Volpeon <git@volpeon.ink> | 2022-11-03 17:56:08 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-03 17:56:08 +0100 |
commit | 036e33fbde6bad7c48bb6f6b3d695b7908535c64 (patch) | |
tree | 6b0df3deed03ccc8722763b51a9d7c2019be8e10 /textual_inversion.py | |
parent | Update (diff) | |
download | textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.tar.gz textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.tar.bz2 textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.zip |
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 14 |
1 file changed, 11 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 115f3aa..578c054 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -25,6 +25,7 @@ from slugify import slugify | |||
25 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler | 25 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler |
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
28 | from training.optimization import get_one_cycle_schedule | ||
28 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
29 | 30 | ||
30 | logger = get_logger(__name__) | 31 | logger = get_logger(__name__) |
@@ -162,10 +163,10 @@ def parse_args(): | |||
162 | parser.add_argument( | 163 | parser.add_argument( |
163 | "--lr_scheduler", | 164 | "--lr_scheduler", |
164 | type=str, | 165 | type=str, |
165 | default="cosine_with_restarts", | 166 | default="one_cycle", |
166 | help=( | 167 | help=( |
167 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 168 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
168 | ' "constant", "constant_with_warmup"]' | 169 | ' "constant", "constant_with_warmup", "one_cycle"]' |
169 | ), | 170 | ), |
170 | ) | 171 | ) |
171 | parser.add_argument( | 172 | parser.add_argument( |
@@ -535,6 +536,8 @@ def main(): | |||
535 | 536 | ||
536 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 537 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
537 | 538 | ||
539 | unet.set_use_memory_efficient_attention_xformers(True) | ||
540 | |||
538 | if args.gradient_checkpointing: | 541 | if args.gradient_checkpointing: |
539 | text_encoder.gradient_checkpointing_enable() | 542 | text_encoder.gradient_checkpointing_enable() |
540 | 543 | ||
@@ -693,7 +696,12 @@ def main(): | |||
693 | overrode_max_train_steps = True | 696 | overrode_max_train_steps = True |
694 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 697 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
695 | 698 | ||
696 | if args.lr_scheduler == "cosine_with_restarts": | 699 | if args.lr_scheduler == "one_cycle": |
700 | lr_scheduler = get_one_cycle_schedule( | ||
701 | optimizer=optimizer, | ||
702 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
703 | ) | ||
704 | elif args.lr_scheduler == "cosine_with_restarts": | ||
697 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 705 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
698 | optimizer=optimizer, | 706 | optimizer=optimizer, |
699 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 707 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |