author     Volpeon <git@volpeon.ink>  2023-03-17 15:18:20 +0100
committer  Volpeon <git@volpeon.ink>  2023-03-17 15:18:20 +0100
commit     8abbd633d8ee7500058b2f1f69a6d6611b5a4450 (patch)
tree       f60d6e384966ba05354b30f08a32a38279b56165
parent     Update (diff)
download   textual-inversion-diff-8abbd633d8ee7500058b2f1f69a6d6611b5a4450.tar.gz
           textual-inversion-diff-8abbd633d8ee7500058b2f1f69a6d6611b5a4450.tar.bz2
           textual-inversion-diff-8abbd633d8ee7500058b2f1f69a6d6611b5a4450.zip
Test: https://arxiv.org/pdf/2303.09556.pdf
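(arXiv 2303.09556 is "Efficient Diffusion Training via Min-SNR Weighting Strategy"; the loss changes in training/functional.py below appear to be a first pass at its Min-SNR-gamma loss weighting with gamma = 5.)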
-rw-r--r--  environment.yaml       | 46
-rw-r--r--  train_ti.py            | 38
-rw-r--r--  training/functional.py | 11
3 files changed, 56 insertions(+), 39 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 018a9ab..9355f37 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,27 +1,25 @@
 name: ldd
 channels:
   - pytorch
   - nvidia
   - xformers/label/dev
   - defaults
 dependencies:
-  - cudatoolkit=11.7
-  - libcufile=1.4.0.31
   - matplotlib=3.6.2
   - numpy=1.23.4
   - pip=22.3.1
   - python=3.10.8
-  - pytorch=1.13.1=*cuda*
-  - torchvision=0.14.1
-  - xformers=0.0.17.dev466
+  - pytorch=2.0.0=*cuda11.8*
+  - torchvision=0.15.0
+  # - xformers=0.0.17.dev476
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
-    - accelerate==0.16.0
-    - bitsandbytes==0.37.0
+    - accelerate==0.17.1
+    - bitsandbytes==0.37.1
     - python-slugify>=6.1.2
-    - safetensors==0.2.8
+    - safetensors==0.3.0
     - setuptools==65.6.3
     - test-tube>=0.7.5
-    - transformers==4.26.1
+    - transformers==4.27.1
     - triton==2.0.0
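Context for the dependency changes: the pin moves to PyTorch 2.0 with CUDA 11.8 and comments out the xformers build. A plausible reason is that PyTorch 2.0 ships a fused scaled-dot-product attention kernel natively, making the separate xformers dependency optional. A minimal sketch of the built-in op (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    # PyTorch 2.0's native fused attention; layout is (batch, heads, seq_len, head_dim)
    q = k = v = torch.randn(1, 8, 64, 40)
    out = F.scaled_dot_product_attention(q, k, v)  # picks a fused kernel when one is available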
diff --git a/train_ti.py b/train_ti.py
index 81938c8..fd23517 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -62,12 +62,6 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--skip_first",
-        type=int,
-        default=0,
-        help="Tokens to skip training for.",
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -80,6 +74,13 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
+    parser.add_argument(
         "--num_vectors",
         type=int,
         nargs='*',
@@ -420,7 +421,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=1e-2,
+        default=1e2,
         type=float,
         help="Embedding decay factor."
     )
@@ -482,6 +483,9 @@ def parse_args():
     if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
@@ -543,8 +547,8 @@ def main():
         tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
+    # vae.set_use_memory_efficient_attention_xformers(True)
+    # unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
@@ -559,6 +563,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
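For reference, a minimal sketch of the convention the new block above relies on: --alias_tokens is a flat, even-length list of alternating (placeholder, initializer) entries, split apart with slice steps. The token values here are made up for illustration:

    # Hypothetical example values; any even-length alternating list works.
    alias_tokens = ["<cat>", "cat", "<dog>", "dog"]
    alias_placeholder_tokens = alias_tokens[::2]   # ["<cat>", "<dog>"]
    alias_initializer_tokens = alias_tokens[1::2]  # ["cat", "dog"]
    pairs = list(zip(alias_placeholder_tokens, alias_initializer_tokens))
    # [("<cat>", "cat"), ("<dog>", "dog")]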
@@ -633,9 +650,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
-        if i < args.skip_first:
-            return
-
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
             metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
diff --git a/training/functional.py b/training/functional.py
index 4565612..2d6553a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -309,8 +309,12 @@ def loss_step(
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
+        snr_weights = 1
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        snr = target / (1 - target)
+        snr /= snr + 1
+        snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device))
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
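Note that the hunk above derives its weight elementwise from the velocity target, whereas the referenced paper defines SNR per timestep from the noise schedule. For comparison, a sketch of the paper's Min-SNR-gamma weights computed from a diffusers-style scheduler (assuming it exposes alphas_cumprod; the helper name is mine):

    import torch

    def min_snr_gamma_weights(noise_scheduler, timesteps, gamma=5.0):
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t), per arXiv:2303.09556.
        alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
        alpha_bar = alphas_cumprod[timesteps]
        snr = alpha_bar / (1 - alpha_bar)
        # For v-prediction the paper uses min(SNR, gamma) / (SNR + 1);
        # for epsilon prediction it uses min(SNR, gamma) / SNR.
        return torch.minimum(snr, torch.full_like(snr, gamma)) / (snr + 1)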
@@ -320,16 +324,17 @@ def loss_step(
         target, target_prior = torch.chunk(target, 2, dim=0)
 
         # Compute instance loss
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
         # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
 
         # Add the prior loss to the instance loss.
         loss = loss + prior_loss_weight * prior_loss
     else:
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
+    loss = (snr_weights * loss).mean()
     acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
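One follow-up on the reduction change: with reduction="none", F.mse_loss keeps per-element losses (e.g. shape (batch, channels, h, w)), which is what lets snr_weights scale the loss before the final .mean(). The elementwise weights in this commit already broadcast; if the weights were instead computed per timestep as in the paper, they would first need reshaping, a sketch:

    # Assuming per-timestep weights of shape (batch,): reshape so they
    # broadcast over the (batch, channels, h, w) per-element loss.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = (snr_weights.view(-1, 1, 1, 1) * loss).mean()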