| field | value | date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-03-21 13:46:36 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2023-03-21 13:46:36 +0100 |
| commit | f5e0e98f6df9260a93fb650a0b97c85eb87b0fd3 (patch) | |
| tree | 0d061f5fd8950d7ca7e0198731ee58980859dd18 /train_lora.py | |
| parent | Restore min SNR (diff) | |
Fixed SNR weighting, re-enabled xformers
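
The SNR fix itself touches code outside train_lora.py (the diffstat below is limited to that file), so it does not appear in this diff. For reference only, here is a minimal sketch of the standard min-SNR-gamma weighting from Hang et al. (2023), assuming an epsilon-prediction objective and a scheduler-provided `alphas_cumprod` tensor; it is not necessarily the exact fix made in this commit:

```python
# Sketch of min-SNR-gamma loss weighting (Hang et al., 2023), assuming epsilon
# prediction. Not the repo's code: the actual change is outside this diff.
import torch


def min_snr_weights(timesteps: torch.Tensor,
                    alphas_cumprod: torch.Tensor,
                    gamma: float = 5.0) -> torch.Tensor:
    """Per-sample loss weights min(SNR(t), gamma) / SNR(t)."""
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)  # SNR(t) of the DDPM forward process
    return torch.minimum(snr, torch.full_like(snr, gamma)) / snr


# Hypothetical usage: weight the unreduced MSE per sample, then average.
# loss = F.mse_loss(model_pred, target, reduction="none").mean(dim=[1, 2, 3])
# loss = (min_snr_weights(timesteps, noise_scheduler.alphas_cumprod) * loss).mean()
```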
Diffstat (limited to 'train_lora.py')

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_lora.py | 36 |

1 file changed, 9 insertions, 27 deletions
```diff
diff --git a/train_lora.py b/train_lora.py
index e65e7be..2a798f3 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -12,8 +12,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -426,34 +424,16 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.lora_rank
-        )
-
-    unet.set_attn_processor(lora_attn_procs)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
-    lora_layers = AttnProcsLayers(unet.attn_processors)
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -505,7 +485,6 @@ def main():
         unet=unet,
         text_encoder=text_encoder,
         vae=vae,
-        lora_layers=lora_layers,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         with_prior_preservation=args.num_class_images != 0,
@@ -540,7 +519,10 @@ def main():
     datamodule.setup()
 
     optimizer = create_optimizer(
-        lora_layers.parameters(),
+        itertools.chain(
+            unet.parameters(),
+            text_encoder.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
```
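
After this change, train_lora.py no longer builds per-layer LoRACrossAttnProcessor modules; instead the whole UNet and text encoder are handed to the optimizer, with xformers attention kept on and gradient checkpointing made optional. Below is a minimal standalone sketch of that pattern using public diffusers/transformers APIs; the base model id, the use of torch.optim.AdamW in place of the repo's create_optimizer helper, and the hard-coded learning rate are assumptions for illustration, not the repo's exact code:

```python
# Minimal sketch of the post-change training setup (assumptions: SD 1.5 weights,
# AdamW instead of the repo's create_optimizer, hard-coded hyperparameters).
# Requires the xformers package for the memory-efficient attention calls.
import itertools

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

model_id = "runwayml/stable-diffusion-v1-5"  # hypothetical base model
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")

# Memory savers kept by this commit.
vae.enable_slicing()
vae.set_use_memory_efficient_attention_xformers(True)
unet.enable_xformers_memory_efficient_attention()

# Optional gradient checkpointing, mirroring the new `if args.gradient_checkpointing` branch.
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()

# Instead of optimizing only AttnProcsLayers(unet.attn_processors), both models'
# parameters are chained into a single optimizer.
optimizer = torch.optim.AdamW(
    itertools.chain(unet.parameters(), text_encoder.parameters()),
    lr=1e-5,
)
```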
