-rw-r--r--  environment.yaml          |  3
-rw-r--r--  models/clip/embeddings.py |  8
-rw-r--r--  train_dreambooth.py       | 11
-rw-r--r--  train_lora.py             | 11
-rw-r--r--  train_ti.py               | 21
-rw-r--r--  training/functional.py    |  7
-rw-r--r--  training/strategy/ti.py   | 15
7 files changed, 31 insertions(+), 45 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 57624a3..1e6ac60 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -13,16 +13,15 @@ dependencies:
   - python=3.10.8
   - pytorch=1.13.1=*cuda*
   - torchvision=0.14.1
+  - xformers=0.0.17.dev461
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
     - accelerate==0.16.0
     - bitsandbytes==0.37.0
-    - lion-pytorch==0.0.6
     - python-slugify>=6.1.2
     - safetensors==0.2.8
     - setuptools==65.6.3
     - test-tube>=0.7.5
     - transformers==4.26.1
     - triton==2.0.0a2
-    - xformers==0.0.17.dev451
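A quick sanity check for the dependency move (a sketch, not part of the repo; it assumes the recreated conda environment is active and that the installed xformers build exposes __version__):

# Hypothetical check, not from the repo: verify the conda-packaged xformers
# build pinned above is the one that actually gets imported.
import xformers
import xformers.ops  # memory-efficient attention ops used by the trainers

print(xformers.__version__)  # expected to match the 0.0.17.dev461 pin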
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6c41c33..734730e 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -98,14 +98,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
-        if lambda_ == 0:
-            return
-
-        w = self.temp_token_embedding.weight
-        norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
-        w[self.temp_token_ids].add_((w[self.temp_token_ids] / norm.clamp_min(1e-12)) * lambda_ * (target - norm))
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e039df0..431ff3d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -288,7 +288,7 @@ def parse_args():
288 "--optimizer", 288 "--optimizer",
289 type=str, 289 type=str,
290 default="adam", 290 default="adam",
291 help='Optimizer to use ["adam", "adam8bit", "lion"]' 291 help='Optimizer to use ["adam", "adam8bit"]'
292 ) 292 )
293 parser.add_argument( 293 parser.add_argument(
294 "--adam_beta1", 294 "--adam_beta1",
@@ -459,7 +459,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, noise_scheduler="deis")
+        args.pretrained_model_name_or_path)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -513,13 +513,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
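For context on the removed branch (the same deletion appears in train_lora.py and train_ti.py below): each --optimizer choice binds an optimizer class plus fixed hyperparameters into a create_optimizer factory that the training code calls later with the parameters and learning rate. A minimal sketch of that pattern, using torch.optim.AdamW as a stand-in and illustrative values rather than the script's own arguments:

from functools import partial

import torch

# Sketch of the factory pattern the deleted 'lion' branch participated in:
# bind the constructor and fixed hyperparameters now, supply params and lr later.
create_optimizer = partial(
    torch.optim.AdamW,
    betas=(0.9, 0.999),   # illustrative values, not the script's args
    eps=1e-8,
    amsgrad=False,
)

params = [torch.nn.Parameter(torch.zeros(8))]
optimizer = create_optimizer(params, lr=1e-4)  # what the training code does later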
diff --git a/train_lora.py b/train_lora.py
index db5330a..a06591d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -248,7 +248,7 @@ def parse_args():
248 "--optimizer", 248 "--optimizer",
249 type=str, 249 type=str,
250 default="adam", 250 default="adam",
251 help='Optimizer to use ["adam", "adam8bit", "lion"]' 251 help='Optimizer to use ["adam", "adam8bit"]'
252 ) 252 )
253 parser.add_argument( 253 parser.add_argument(
254 "--adam_beta1", 254 "--adam_beta1",
@@ -419,7 +419,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, noise_scheduler="deis")
+        args.pretrained_model_name_or_path)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -488,13 +488,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
diff --git a/train_ti.py b/train_ti.py
index 12e3644..6dc07dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -86,7 +86,7 @@ def parse_args():
86 help="Number of vectors per embedding." 86 help="Number of vectors per embedding."
87 ) 87 )
88 parser.add_argument( 88 parser.add_argument(
89 "--simultaneous", 89 "--sequential",
90 action="store_true", 90 action="store_true",
91 ) 91 )
92 parser.add_argument( 92 parser.add_argument(
@@ -293,7 +293,7 @@ def parse_args():
293 "--optimizer", 293 "--optimizer",
294 type=str, 294 type=str,
295 default="adam", 295 default="adam",
296 help='Optimizer to use ["adam", "adam8bit", "lion"]' 296 help='Optimizer to use ["adam", "adam8bit"]'
297 ) 297 )
298 parser.add_argument( 298 parser.add_argument(
299 "--adam_beta1", 299 "--adam_beta1",
@@ -343,6 +343,11 @@ def parse_args():
343 help="How often to save a checkpoint and sample image (in epochs)", 343 help="How often to save a checkpoint and sample image (in epochs)",
344 ) 344 )
345 parser.add_argument( 345 parser.add_argument(
346 "--no_milestone_checkpoints",
347 action='store_true',
348 help="If checkpoints are saved on maximum accuracy",
349 )
350 parser.add_argument(
346 "--sample_frequency", 351 "--sample_frequency",
347 type=int, 352 type=int,
348 default=1, 353 default=1,
@@ -480,7 +485,7 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if not args.simultaneous:
+    if args.sequential:
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -586,13 +591,6 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'lion':
-        try:
-            from lion_pytorch import Lion
-        except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
-
-        create_optimizer = partial(Lion, use_triton=True)
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -615,6 +613,7 @@ def main():
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
+        milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         # --
         tokenizer=tokenizer,
@@ -715,7 +714,7 @@ def main():
 
     plot_metrics(metrics, metrics_output_file)
 
-    if args.simultaneous:
+    if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
         for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
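Both new switches are plain store_true flags, so the defaults invert the old behavior: all placeholder tokens train simultaneously unless --sequential is passed, and milestone checkpoints stay enabled unless --no_milestone_checkpoints is passed. A minimal argparse sketch of just these two flags (hypothetical, standalone, not the script itself):

import argparse

# Hypothetical standalone sketch of the renamed/added flags.
parser = argparse.ArgumentParser()
parser.add_argument("--sequential", action="store_true")
parser.add_argument("--no_milestone_checkpoints", action="store_true")

args = parser.parse_args([])                  # defaults
assert not args.sequential                    # simultaneous training is now the default
assert not args.no_milestone_checkpoints      # milestone checkpoints stay on by default

args = parser.parse_args(["--sequential"])    # opt in to one-token-at-a-time runs
assert args.sequential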
diff --git a/training/functional.py b/training/functional.py
index 85dd884..739d055 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -362,6 +362,7 @@ def train_loop(
     loss_step: LossCallable,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     callbacks: TrainingCallbacks = TrainingCallbacks(),
@@ -514,7 +515,7 @@ def train_loop(
             accelerator.log(logs, step=global_step)
 
             if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > best_acc_val:
+                if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
@@ -527,7 +528,7 @@ def train_loop(
                 accs.append(avg_acc_val.avg.item())
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg.item() > best_acc:
+                if avg_acc.avg.item() > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
@@ -572,6 +573,7 @@ def train(
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
+    milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
@@ -626,6 +628,7 @@ def train(
         loss_step=loss_step_,
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
+        milestone_checkpoints=milestone_checkpoints,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         callbacks=callbacks,
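The new milestone_checkpoints flag only gates the "new best accuracy" saves; the periodic checkpoint_frequency saves are unaffected. A simplified sketch of that gating (hypothetical helper names; it also assumes best_acc only advances when a milestone save happens, which may differ from the full train_loop):

def save_checkpoint(reason: str) -> None:
    # Stand-in for the real checkpoint writer used by the training strategies.
    print(f"saving checkpoint ({reason})")

def maybe_save_milestone(avg_acc: float, best_acc: float,
                         milestone_checkpoints: bool = True) -> float:
    # Mirrors the gated condition added to train_loop: a new best accuracy
    # triggers a save only when milestone checkpoints are enabled.
    if avg_acc > best_acc and milestone_checkpoints:
        save_checkpoint("milestone")
        best_acc = avg_acc
    return best_acc

best = 0.0
for epoch_acc in (0.52, 0.61, 0.58):
    best = maybe_save_milestone(epoch_acc, best, milestone_checkpoints=True)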
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 66d3129..09beec4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -116,10 +116,17 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            text_encoder.text_model.embeddings.normalize(
-                emb_decay_target,
-                min(1.0, emb_decay * lr)
-            )
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.zeros(w.size(0), dtype=torch.bool)
+                mask[text_encoder.text_model.embeddings.temp_token_ids] = True
+                mask[torch.all(w.grad == 0, dim=1)] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_after_optimize(lr: float):
         if ema_embeddings is not None:
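The decay that used to live in ManagedCLIPTextEmbeddings.normalize now runs in the TI strategy and only touches temporary-token rows that received a non-zero gradient in the current step. A standalone sketch of the same update with made-up shapes and constants (no CLIP or strategy objects involved); note the sketch writes back through w[mask] += ... so the nudge is applied to w itself:

import torch

# Made-up stand-ins for temp_token_embedding.weight, its gradients, and the
# decay hyperparameters; only the masked-decay arithmetic is the point here.
emb_decay_target, emb_decay, lr = 0.4, 1e-2, 5e-3
w = torch.randn(10, 768, requires_grad=True)
w.grad = torch.zeros_like(w)
temp_token_ids = torch.tensor([3, 4, 7])
w.grad[temp_token_ids[:2]] = torch.randn(2, 768)    # pretend only two rows got gradients

lambda_ = emb_decay * lr
if lambda_ != 0:
    mask = torch.zeros(w.size(0), dtype=torch.bool)
    mask[temp_token_ids] = True                      # candidate rows: the temporary tokens
    mask[torch.all(w.grad == 0, dim=1)] = False      # drop rows with no gradient this step

    with torch.no_grad():                            # the real callback is @torch.no_grad() decorated
        norm = w[mask, :].norm(dim=-1, keepdim=True)
        w[mask] += (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)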