Diffstat (limited to 'train_lora.py')
 train_lora.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 6e72376..e65e7be 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -292,6 +292,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=256,
+        help="LoRA rank.",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
@@ -420,10 +426,6 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
@@ -439,11 +441,18 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.lora_rank
         )
 
     unet.set_attn_processor(lora_attn_procs)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
     lora_layers = AttnProcsLayers(unet.attn_processors)
 
     if args.embeddings_dir is not None:
