 pipelines/stable_diffusion/vlpn_stable_diffusion.py |  2
 train_lora.py                                       | 23
 train_ti.py                                         |  8
 training/functional.py                              |  6
 training/strategy/lora.py                           | 25
 5 files changed, 37 insertions(+), 27 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f426de1..4505a2a 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -394,7 +394,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = "noise",
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
diff --git a/train_lora.py b/train_lora.py
index 6e72376..e65e7be 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -292,6 +292,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=256,
+        help="LoRA rank.",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
@@ -420,10 +426,6 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
@@ -439,11 +441,18 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.lora_rank
         )
 
     unet.set_attn_processor(lora_attn_procs)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
     lora_layers = AttnProcsLayers(unet.attn_processors)
 
     if args.embeddings_dir is not None:
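
For reference, a minimal standalone sketch of what the new --lora_rank argument controls once it reaches diffusers' LoRACrossAttnProcessor; the hidden_size and cross_attention_dim values below are illustrative SD 1.x numbers, not taken from this commit:

from diffusers.models.cross_attention import LoRACrossAttnProcessor

# The rank bounds the low-rank matrices LoRA inserts into each attention
# projection, so a smaller rank means fewer trainable parameters.
proc = LoRACrossAttnProcessor(
    hidden_size=320,            # e.g. width of the first SD 1.x down block
    cross_attention_dim=768,    # e.g. CLIP text embedding width for SD 1.x
    rank=4,                     # exposed by the training script as --lora_rank
)
print(sum(p.numel() for p in proc.parameters()))  # parameter count grows with rank
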
diff --git a/train_ti.py b/train_ti.py
index b9d6e56..81938c8 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -476,13 +476,10 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
-    if args.num_vectors is None:
-        args.num_vectors = 1
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
-    if len(args.placeholder_tokens) != len(args.num_vectors):
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
     if args.sequential:
@@ -491,6 +488,9 @@ def parse_args():
 
         if len(args.placeholder_tokens) != len(args.train_data_template):
             raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+
+        if args.num_vectors is None:
+            args.num_vectors = [None] * len(args.placeholder_tokens)
     else:
         if isinstance(args.train_data_template, list):
             raise ValueError("--train_data_template can't be a list in simultaneous mode")
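
A short sketch of the resulting --num_vectors handling in parse_args(), with hypothetical values standing in for the argparse namespace:

placeholder_tokens = ["<tok1>", "<tok2>"]

num_vectors = 3                     # a single int passed via --num_vectors
if isinstance(num_vectors, int):
    num_vectors = [num_vectors] * len(placeholder_tokens)   # -> [3, 3]

num_vectors = None                  # --num_vectors omitted in sequential mode
if num_vectors is None:
    num_vectors = [None] * len(placeholder_tokens)          # -> [None, None]
# The None entries are resolved later by add_placeholder_tokens().
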
diff --git a/training/functional.py b/training/functional.py
index 27a43c2..4565612 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,12 +231,16 @@ def add_placeholder_tokens(
     embeddings: ManagedCLIPTextEmbeddings,
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
+    num_vectors: Optional[Union[list[int], int]] = None,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
         for token in initializer_tokens
     ]
+
+    if num_vectors is None:
+        num_vectors = [len(ids) for ids in initializer_token_ids]
+
     placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
 
     embeddings.resize(len(tokenizer))
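
The new default can be reproduced in isolation; a rough sketch using an off-the-shelf CLIP tokenizer and made-up initializer strings (actual sub-token counts depend on the tokenizer):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

initializer_tokens = ["dog", "shiba inu"]    # hypothetical initializers
initializer_token_ids = [
    tokenizer.encode(token, add_special_tokens=False)
    for token in initializer_tokens
]

# With num_vectors omitted, each placeholder now gets one embedding vector
# per initializer sub-token, e.g. [1, 3] if "shiba inu" splits into three.
num_vectors = [len(ids) for ids in initializer_token_ids]
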
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ccec215..cab5e4c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -11,10 +11,7 @@ from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.loaders import AttnProcsLayers
 
-from slugify import slugify
-
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -41,16 +38,9 @@ def lora_strategy_callbacks(
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
 
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
         text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
@@ -83,20 +73,27 @@ def lora_strategy_callbacks(
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
+        unet_.save_attn_procs(
+            checkpoint_output_dir / f"{step}_{postfix}",
+            safe_serialization=True
+        )
         del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        save_samples_(step=step)
+        unet_ = accelerator.unwrap_model(unet, False)
+        save_samples_(step=step, unet=unet_)
+        del unet_
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
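
Since checkpoints are now written with safe_serialization=True, here is a hedged sketch of loading such a LoRA checkpoint for inference; the paths and prompt are hypothetical, and it assumes a diffusers version whose load_attn_procs can read the safetensors file produced by save_attn_procs:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("path/to/base-model")   # hypothetical base model
pipe.unet.load_attn_procs("output/model/checkpoints/10000_end")        # directory written by on_checkpoint
image = pipe("a photo in the learned style").images[0]                 # hypothetical prompt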