author     Volpeon <git@volpeon.ink>  2023-03-07 07:11:51 +0100
committer  Volpeon <git@volpeon.ink>  2023-03-07 07:11:51 +0100
commit     fe3113451fdde72ddccfc71639f0a2a1e146209a
tree       ba4114faf1bd00a642f97b5e7729ad74213c3b80 /train_lora.py
parent     Update
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  23
1 file changed, 16 insertions(+), 7 deletions(-)
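
In short: this commit makes the LoRA rank configurable via a new --lora_rank argument (default 256) and passes it through to LoRACrossAttnProcessor; it also moves the VAE slicing and xformers setup to after the LoRA attention processors are installed, dropping the direct LoRAXFormersCrossAttnProcessor import. A hedged sketch of the resulting pattern follows the diff below.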
diff --git a/train_lora.py b/train_lora.py
index 6e72376..e65e7be 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -292,6 +292,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=256,
+        help="LoRA rank.",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
@@ -420,10 +426,6 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
@@ -439,11 +441,18 @@
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.lora_rank
         )
 
     unet.set_attn_processor(lora_attn_procs)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
     lora_layers = AttnProcsLayers(unet.attn_processors)
 
     if args.embeddings_dir is not None:
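
Note: a minimal standalone sketch of the pattern this commit moves to. It assumes a diffusers release of this vintage (~0.14, where LoRACrossAttnProcessor still lives in diffusers.models.cross_attention) and an xformers install; the model ID, rank value, and optimizer are illustrative, not taken from this repository. In diffusers of this era, enable_xformers_memory_efficient_attention() replaces each installed LoRACrossAttnProcessor with its xformers counterpart while keeping the LoRA weights, which appears to be why the calls were reordered and the direct LoRAXFormersCrossAttnProcessor import dropped.

    import torch
    from diffusers import UNet2DConditionModel
    from diffusers.loaders import AttnProcsLayers
    from diffusers.models.cross_attention import LoRACrossAttnProcessor

    # Illustrative model ID; the training script loads it via get_models().
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )

    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # attn1 processors are self-attention (no encoder states);
        # attn2 processors cross-attend to the text embeddings.
        cross_attention_dim = (
            None if name.endswith("attn1.processor")
            else unet.config.cross_attention_dim
        )
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        # rank is what the commit exposes as --lora_rank (default 256).
        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            rank=256,
        )

    unet.set_attn_processor(lora_attn_procs)

    # Enabling xformers only after the LoRA processors are installed lets
    # diffusers convert them to their xformers variants in place.
    unet.enable_xformers_memory_efficient_attention()

    # Only the LoRA parameters are trained; the base UNet stays frozen.
    lora_layers = AttnProcsLayers(unet.attn_processors)
    optimizer = torch.optim.AdamW(lora_layers.parameters(), lr=1e-4)

With the new flag, the rank can be set per run, e.g. python train_lora.py --lora_rank 128 (remaining arguments unchanged); lower ranks train fewer LoRA parameters, higher ranks give the adapter more capacity.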