diff options
author | Volpeon <git@volpeon.ink> | 2023-02-08 11:38:56 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-08 11:38:56 +0100 |
commit | 347ad308f8223d966793f0421c72432f7e912377 (patch) | |
tree | 2b7319dc37787ce2828101c451987d086dd47360 /train_lora.py | |
parent | Fixed Lora training (diff) | |
download | textual-inversion-diff-347ad308f8223d966793f0421c72432f7e912377.tar.gz textual-inversion-diff-347ad308f8223d966793f0421c72432f7e912377.tar.bz2 textual-inversion-diff-347ad308f8223d966793f0421c72432f7e912377.zip |
Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 8 |
1 files changed, 0 insertions, 8 deletions
diff --git a/train_lora.py b/train_lora.py index ab1753b..5fd05cc 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -178,11 +178,6 @@ def parse_args(): | |||
178 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 178 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
179 | ) | 179 | ) |
180 | parser.add_argument( | 180 | parser.add_argument( |
181 | "--gradient_checkpointing", | ||
182 | action="store_true", | ||
183 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | ||
184 | ) | ||
185 | parser.add_argument( | ||
186 | "--find_lr", | 181 | "--find_lr", |
187 | action="store_true", | 182 | action="store_true", |
188 | help="Automatically find a learning rate (no training).", | 183 | help="Automatically find a learning rate (no training).", |
@@ -429,9 +424,6 @@ def main(): | |||
429 | vae.set_use_memory_efficient_attention_xformers(True) | 424 | vae.set_use_memory_efficient_attention_xformers(True) |
430 | unet.enable_xformers_memory_efficient_attention() | 425 | unet.enable_xformers_memory_efficient_attention() |
431 | 426 | ||
432 | if args.gradient_checkpointing: | ||
433 | unet.enable_gradient_checkpointing() | ||
434 | |||
435 | unet.to(accelerator.device, dtype=weight_dtype) | 427 | unet.to(accelerator.device, dtype=weight_dtype) |
436 | text_encoder.to(accelerator.device, dtype=weight_dtype) | 428 | text_encoder.to(accelerator.device, dtype=weight_dtype) |
437 | 429 | ||