From f5e0e98f6df9260a93fb650a0b97c85eb87b0fd3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 21 Mar 2023 13:46:36 +0100
Subject: Fixed SNR weighting, re-enabled xformers

---
 train_lora.py | 36 +++++++++---------------------------
 1 file changed, 9 insertions(+), 27 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index e65e7be..2a798f3 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -12,8 +12,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -426,34 +424,16 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.lora_rank
-        )
-
-    unet.set_attn_processor(lora_attn_procs)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
-    lora_layers = AttnProcsLayers(unet.attn_processors)
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -505,7 +485,6 @@ def main():
         unet=unet,
         text_encoder=text_encoder,
         vae=vae,
-        lora_layers=lora_layers,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         with_prior_preservation=args.num_class_images != 0,
@@ -540,7 +519,10 @@ def main():
     datamodule.setup()
 
     optimizer = create_optimizer(
-        lora_layers.parameters(),
+        itertools.chain(
+            unet.parameters(),
+            text_encoder.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
-- 
cgit v1.2.3-54-g00ecf
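
Note on the commit subject: the "Fixed SNR weighting" part of this change does not appear in the train_lora.py hunks above, since the cgit view is limited to that file. For context only, below is a minimal, generic Min-SNR loss-weighting sketch for epsilon-prediction, assuming a diffusers-style scheduler that exposes alphas_cumprod; the function name min_snr_weights and the gamma default are illustrative and not taken from this repository.

    import torch

    def min_snr_weights(noise_scheduler, timesteps, gamma=5.0):
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for each sampled timestep.
        alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
        snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])
        # Min-SNR weighting for epsilon-prediction: min(SNR, gamma) / SNR.
        return torch.minimum(snr, torch.full_like(snr, gamma)) / snr

    # Illustrative use inside a training step (variable names assumed):
    #   loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    #   loss = loss.mean(dim=list(range(1, len(loss.shape))))
    #   loss = (loss * min_snr_weights(noise_scheduler, timesteps)).mean()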