Diffstat (limited to 'train_lora.py')
 train_lora.py | 36 +++++++++---------------------------
 1 file changed, 9 insertions(+), 27 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index e65e7be..2a798f3 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -12,8 +12,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -426,34 +424,16 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.lora_rank
-        )
-
-    unet.set_attn_processor(lora_attn_procs)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
-    lora_layers = AttnProcsLayers(unet.attn_processors)
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
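
The removed block above wired a diffusers LoRACrossAttnProcessor into every
attention layer and trained only those adapter weights; the new code trains
the full UNet and text encoder instead, reclaiming memory through gradient
checkpointing. A minimal sketch of the added calls, assuming stock
diffusers/transformers models (the model id and the bare flag are
illustrative stand-ins, not taken from this repo):

    from diffusers import UNet2DConditionModel
    from transformers import CLIPTextModel

    repo = "runwayml/stable-diffusion-v1-5"  # placeholder model id
    unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet")
    text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")

    gradient_checkpointing = True  # stands in for args.gradient_checkpointing
    if gradient_checkpointing:
        # Recompute activations during the backward pass instead of storing
        # them, trading extra compute for much lower activation memory.
        unet.enable_gradient_checkpointing()          # diffusers API
        text_encoder.gradient_checkpointing_enable()  # transformers API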
@@ -505,7 +485,6 @@ def main():
         unet=unet,
         text_encoder=text_encoder,
         vae=vae,
-        lora_layers=lora_layers,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         with_prior_preservation=args.num_class_images != 0,
@@ -540,7 +519,10 @@ def main():
     datamodule.setup()
 
     optimizer = create_optimizer(
-        lora_layers.parameters(),
+        itertools.chain(
+            unet.parameters(),
+            text_encoder.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
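
With lora_layers gone, the optimizer has to cover both models' weights;
itertools.chain stitches the two parameter iterators into one. A minimal
sketch of the pattern (AdamW and the tiny stand-in modules are assumptions,
since create_optimizer's internals are not shown in this diff):

    import itertools

    import torch

    # Any pair of nn.Modules behaves like the real unet / text_encoder here.
    unet = torch.nn.Linear(8, 8)
    text_encoder = torch.nn.Linear(8, 8)

    # One optimizer instance updates both parameter sets in a single step();
    # the chained iterator is consumed once, at construction time.
    optimizer = torch.optim.AdamW(
        itertools.chain(unet.parameters(), text_encoder.parameters()),
        lr=1e-5,  # placeholder for args.learning_rate
    )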