-rw-r--r--  environment.yaml        46
-rw-r--r--  train_ti.py             38
-rw-r--r--  training/functional.py  11
3 files changed, 56 insertions, 39 deletions
diff --git a/environment.yaml b/environment.yaml
index 018a9ab..9355f37 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,27 +1,25 @@
 name: ldd
 channels:
   - pytorch
   - nvidia
   - xformers/label/dev
   - defaults
 dependencies:
-  - cudatoolkit=11.7
-  - libcufile=1.4.0.31
   - matplotlib=3.6.2
   - numpy=1.23.4
   - pip=22.3.1
   - python=3.10.8
-  - pytorch=1.13.1=*cuda*
-  - torchvision=0.14.1
-  - xformers=0.0.17.dev466
+  - pytorch=2.0.0=*cuda11.8*
+  - torchvision=0.15.0
+  # - xformers=0.0.17.dev476
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
-    - accelerate==0.16.0
-    - bitsandbytes==0.37.0
+    - accelerate==0.17.1
+    - bitsandbytes==0.37.1
     - python-slugify>=6.1.2
-    - safetensors==0.2.8
+    - safetensors==0.3.0
     - setuptools==65.6.3
     - test-tube>=0.7.5
-    - transformers==4.26.1
+    - transformers==4.27.1
     - triton==2.0.0
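
Note: the bump above moves from PyTorch 1.13.1 on CUDA 11.7 to the PyTorch 2.0.0 cuda11.8 build and leaves the xformers pin commented out. A minimal sanity-check sketch, assuming the environment is recreated with "conda env update -f environment.yaml" (printed values depend on the local machine):

# Sanity check after recreating the environment from environment.yaml.
# Printed values depend on the machine; the comments mirror the pins above.
import importlib.util

import torch

print(torch.__version__)          # pinned to a 2.0.0 build above
print(torch.version.cuda)         # expected "11.8" for the *cuda11.8* build
print(torch.cuda.is_available())  # True on a working CUDA setup

# PyTorch 2.0 ships scaled_dot_product_attention, which is presumably why
# the explicit xformers pin is commented out above.
print(hasattr(torch.nn.functional, "scaled_dot_product_attention"))

# xformers is now optional; report whether it is still importable.
print(importlib.util.find_spec("xformers") is not None)
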
diff --git a/train_ti.py b/train_ti.py
index 81938c8..fd23517 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -62,12 +62,6 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--skip_first",
-        type=int,
-        default=0,
-        help="Tokens to skip training for.",
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -80,6 +74,13 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
+    parser.add_argument(
         "--num_vectors",
         type=int,
         nargs='*',
@@ -420,7 +421,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=1e-2,
+        default=1e2,
         type=float,
         help="Embedding decay factor."
     )
@@ -482,6 +483,9 @@ def parse_args():
     if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
@@ -543,8 +547,8 @@ def main():
     tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
+    # vae.set_use_memory_efficient_attention_xformers(True)
+    # unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
@@ -559,6 +563,19 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
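
Note: the new --alias_tokens argument is a flat list of alias/initializer pairs; the check added in parse_args() requires an even number of items, and the hunk above splits the list with [::2] and [1::2] before handing both halves to add_placeholder_tokens(). A standalone sketch of just that pairing logic, with made-up token strings and the repo helper left out:

# Sketch of how --alias_tokens is consumed: a flat list of
# <alias> <existing token> pairs. The token strings here are hypothetical.
alias_tokens = ["<kitten>", "cat", "<puppy>", "dog"]

if len(alias_tokens) % 2 != 0:
    raise ValueError("--alias_tokens must be a list with an even number of items")

alias_placeholder_tokens = alias_tokens[::2]   # ["<kitten>", "<puppy>"]
alias_initializer_tokens = alias_tokens[1::2]  # ["cat", "dog"]

# train_ti.py passes both lists to add_placeholder_tokens() and persists the
# embeddings; here we only show how the pairs line up.
for placeholder, initializer in zip(alias_placeholder_tokens, alias_initializer_tokens):
    print(f"{placeholder} -> {initializer}")
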
@@ -633,9 +650,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
-        if i < args.skip_first:
-            return
-
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
             metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
diff --git a/training/functional.py b/training/functional.py
index 4565612..2d6553a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -309,8 +309,12 @@ def loss_step(
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
+        snr_weights = 1
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        snr = target / (1 - target)
+        snr /= snr + 1
+        snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device))
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
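
Note: for epsilon prediction the weight is left at 1; for v-prediction the commit derives a per-element weight from the velocity target and clamps it at 5 via torch.minimum. A standalone reproduction of that branch, with a random tensor standing in for the value returned by get_velocity():

import torch

# Random stand-in for the velocity target; shape matches a latent-space batch.
target = torch.randn(4, 4, 64, 64)

snr = target / (1 - target)
snr /= snr + 1
snr_weights = torch.minimum(snr, torch.tensor([5.0]))

print(snr_weights.shape)         # same shape as the target, one weight per element
print(float(snr_weights.max()))  # torch.minimum caps every element at 5
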
@@ -320,16 +324,17 @@ def loss_step(
         target, target_prior = torch.chunk(target, 2, dim=0)
 
         # Compute instance loss
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
         # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
 
         # Add the prior loss to the instance loss.
         loss = loss + prior_loss_weight * prior_loss
     else:
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
+    loss = (snr_weights * loss).mean()
     acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
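
Note: switching the MSE calls to reduction="none" keeps the per-element losses, so the weights computed in the previous hunk can be applied before a single final mean. A minimal sketch of that reduction change in isolation, with random tensors and a constant stand-in for the clamped weights:

import torch
import torch.nn.functional as F

model_pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)
snr_weights = torch.full_like(target, 0.5)  # stand-in for the clamped weights

# reduction="none" keeps the element-wise squared errors instead of averaging.
per_element = F.mse_loss(model_pred.float(), target.float(), reduction="none")
print(per_element.shape)  # torch.Size([4, 4, 64, 64])

# The weights are applied element-wise before collapsing to a scalar,
# matching the new loss = (snr_weights * loss).mean() line above.
loss = (snr_weights * per_element).mean()
print(loss.item())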