author    Volpeon <git@volpeon.ink>  2022-11-03 17:56:08 +0100
committer Volpeon <git@volpeon.ink>  2022-11-03 17:56:08 +0100
commit    036e33fbde6bad7c48bb6f6b3d695b7908535c64 (patch)
tree      6b0df3deed03ccc8722763b51a9d7c2019be8e10 /textual_inversion.py
parent    Update (diff)
Update
Diffstat (limited to 'textual_inversion.py')
 textual_inversion.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 115f3aa..578c054 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,6 +25,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -162,10 +163,10 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine_with_restarts",
+        default="one_cycle",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
         ),
     )
     parser.add_argument(
@@ -535,6 +536,8 @@ def main():
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
     if args.gradient_checkpointing:
         text_encoder.gradient_checkpointing_enable()
 
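Note: the hunk above enables xformers memory-efficient attention on the UNet unconditionally, so training will fail if xformers is not installed. As an illustrative sketch (not part of this commit), a more defensive variant would only flip the switch when the package is importable:

    # Sketch, not from this commit: enable xformers attention only when available.
    try:
        import xformers  # noqa: F401  (presence check only)
        unet.set_use_memory_efficient_attention_xformers(True)
    except ImportError:
        pass  # fall back to the default attention implementation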
@@ -693,7 +696,12 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    if args.lr_scheduler == "cosine_with_restarts":
+    if args.lr_scheduler == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+    elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
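Note: get_one_cycle_schedule is imported from training.optimization, which is not shown in this diff. A minimal sketch of what a schedule with this call signature could look like, assuming a linear warmup followed by cosine annealing (the warmup_fraction split and the decay curve are assumptions, not the repository's actual implementation):

    import math
    from torch.optim.lr_scheduler import LambdaLR

    def get_one_cycle_schedule(optimizer, num_training_steps, warmup_fraction=0.3, last_epoch=-1):
        """One-cycle LR: ramp up linearly, then anneal back down along a cosine curve."""
        warmup_steps = int(num_training_steps * warmup_fraction)

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                # Linear warmup from 0 to the optimizer's base learning rate.
                return current_step / max(1, warmup_steps)
            # Cosine decay from the base learning rate back toward 0.
            progress = (current_step - warmup_steps) / max(1, num_training_steps - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        return LambdaLR(optimizer, lr_lambda, last_epoch)

Called as in the hunk above, with num_training_steps scaled by gradient_accumulation_steps, the returned LambdaLR can be stepped once per micro-batch so that the cycle spans exactly the configured number of optimizer updates.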