path: root/train_lora.py
author     Volpeon <git@volpeon.ink>    2023-03-21 13:46:36 +0100
committer  Volpeon <git@volpeon.ink>    2023-03-21 13:46:36 +0100
commit     f5e0e98f6df9260a93fb650a0b97c85eb87b0fd3 (patch)
tree       0d061f5fd8950d7ca7e0198731ee58980859dd18 /train_lora.py
parent     Restore min SNR (diff)
Fixed SNR weighting, re-enabled xformers
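Note: the SNR-weighting fix itself is not visible in this diff (the diffstat below is limited to train_lora.py). For reference, a minimal sketch of the min-SNR-gamma loss weighting that the message and the parent commit ("Restore min SNR") refer to; the names min_snr_weights and min_snr_gamma are illustrative and not taken from this repository.

# Hedged sketch of min-SNR loss weighting (Hang et al., 2023); assumes an
# epsilon-prediction objective and a diffusers-style scheduler exposing
# `alphas_cumprod`. Not the repository's implementation.
import torch

def min_snr_weights(noise_scheduler, timesteps, min_snr_gamma=5.0):
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
    snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])
    # Cap the per-timestep weight so low-noise (high-SNR) steps do not dominate.
    return torch.minimum(snr, torch.full_like(snr, min_snr_gamma)) / snr

# Usage: loss = (min_snr_weights(noise_scheduler, timesteps) * per_sample_mse).mean()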
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py | 36
1 file changed, 9 insertions(+), 27 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index e65e7be..2a798f3 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -12,8 +12,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -426,34 +424,16 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.lora_rank
-        )
-
-    unet.set_attn_processor(lora_attn_procs)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
-    lora_layers = AttnProcsLayers(unet.attn_processors)
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -505,7 +485,6 @@ def main():
         unet=unet,
         text_encoder=text_encoder,
         vae=vae,
-        lora_layers=lora_layers,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         with_prior_preservation=args.num_class_images != 0,
@@ -540,7 +519,10 @@ def main():
     datamodule.setup()
 
     optimizer = create_optimizer(
-        lora_layers.parameters(),
+        itertools.chain(
+            unet.parameters(),
+            text_encoder.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
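With lora_layers removed, train_lora.py now hands the full UNet and text-encoder parameter sets to the optimizer instead of only the LoRA attention-processor weights. A minimal sketch of the resulting wiring, assuming an AdamW-backed create_optimizer; the factory shown here is illustrative, not the repository's own helper.

# Sketch of the post-commit optimizer setup; `create_optimizer` is an assumed
# AdamW factory, not copied from this repository.
import itertools
import torch

def create_optimizer(params, lr):
    return torch.optim.AdamW(params, lr=lr)

def build_optimizer(unet, text_encoder, learning_rate):
    # Both models are trained end to end; the removed
    # AttnProcsLayers(unet.attn_processors) parameter set is no longer used.
    return create_optimizer(
        itertools.chain(unet.parameters(), text_encoder.parameters()),
        lr=learning_rate,
    )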