summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-03 17:56:08 +0100
committerVolpeon <git@volpeon.ink>2022-11-03 17:56:08 +0100
commit036e33fbde6bad7c48bb6f6b3d695b7908535c64 (patch)
tree6b0df3deed03ccc8722763b51a9d7c2019be8e10 /dreambooth.py
parentUpdate (diff)
downloadtextual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.tar.gz
textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.tar.bz2
textual-inversion-diff-036e33fbde6bad7c48bb6f6b3d695b7908535c64.zip
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py
index c0caf03..8c4bf50 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -26,6 +26,7 @@ from slugify import slugify
26from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler 26from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
27from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 27from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
28from data.csv import CSVDataModule 28from data.csv import CSVDataModule
29from training.optimization import get_one_cycle_schedule
29from models.clip.prompt import PromptProcessor 30from models.clip.prompt import PromptProcessor
30 31
31logger = get_logger(__name__) 32logger = get_logger(__name__)
@@ -178,10 +179,10 @@ def parse_args():
178 parser.add_argument( 179 parser.add_argument(
179 "--lr_scheduler", 180 "--lr_scheduler",
180 type=str, 181 type=str,
181 default="cosine_with_restarts", 182 default="one_cycle",
182 help=( 183 help=(
183 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 184 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
184 ' "constant", "constant_with_warmup"]' 185 ' "constant", "constant_with_warmup", "one_cycle"]'
185 ), 186 ),
186 ) 187 )
187 parser.add_argument( 188 parser.add_argument(
@@ -585,6 +586,8 @@ def main():
585 device=accelerator.device 586 device=accelerator.device
586 ) 587 )
587 588
589 unet.set_use_memory_efficient_attention_xformers(True)
590
588 if args.gradient_checkpointing: 591 if args.gradient_checkpointing:
589 unet.enable_gradient_checkpointing() 592 unet.enable_gradient_checkpointing()
590 text_encoder.gradient_checkpointing_enable() 593 text_encoder.gradient_checkpointing_enable()
@@ -784,7 +787,12 @@ def main():
784 overrode_max_train_steps = True 787 overrode_max_train_steps = True
785 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 788 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
786 789
787 if args.lr_scheduler == "cosine_with_restarts": 790 if args.lr_scheduler == "one_cycle":
791 lr_scheduler = get_one_cycle_schedule(
792 optimizer=optimizer,
793 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
794 )
795 elif args.lr_scheduler == "cosine_with_restarts":
788 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 796 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
789 optimizer=optimizer, 797 optimizer=optimizer,
790 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 798 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,